"""Observation building and scoring mixin for MolForge.""" from __future__ import annotations from copy import deepcopy from typing import Any, Dict, List, Mapping from .shared import ( DEFAULT_TOOL_COSTS, EDITABLE_SLOTS, ROLE_MESSAGE_TYPES, ROLE_PERMISSIONS, SCENARIOS, SLOT_ORDER, compute_objective_score, enumerate_candidate_edits, evaluate_constraint_margins, evaluate_constraints, literature_hints, molecule_to_smiles, oracle_backend_status, ) try: from ..models import ConstraintCheck, MolForgeObservation, MoleculeSlot, RoleObservation except ImportError: from models import ConstraintCheck, MolForgeObservation, MoleculeSlot, RoleObservation class MolForgeViewMixin: """Observation, report-card, and grader methods.""" def _build_observation( self, *, reward: float, done: bool, reward_components: List, ) -> MolForgeObservation: current_signature = self._molecule_signature() current_assays = [ reading for reading in self._known_assays if reading.molecule_signature == current_signature ] visible_metrics = { "budget_fraction_remaining": round( self._state.remaining_budget / max(self._scenario.oracle_budget, 1), 4 ), "current_molecule_assay_count": float(len(current_assays)), } for property_name in ["potency", "toxicity", "synth", "novelty"]: estimate = self._current_property_estimate(property_name, current_signature) if estimate is not None: visible_metrics[property_name] = estimate constraint_status = self._build_visible_constraints(current_signature) metadata: Dict[str, Any] = { "task_index": self._reset_index % len(SCENARIOS), "oracle_budget_costs": deepcopy(DEFAULT_TOOL_COSTS), "history_length": len(self._history), "trace_tail": [entry["summary"] for entry in self._history[-3:]], "current_smiles": molecule_to_smiles(self._molecule), "oracle_backend": oracle_backend_status(), "candidate_edits": [ {"slot": slot, "fragment": fragment} for slot, fragment in list(enumerate_candidate_edits(self._molecule))[:8] ], "literature_hints": literature_hints(self._molecule), "target_shift_active": self._target_shift_active(), "public_role_metrics": { role: { "messages_sent": metrics["messages_sent"], "correct_messages": metrics["correct_messages"], } for role, metrics in self._role_metrics.items() }, } if done: metadata["terminal_grader_scores"] = self._grade_all() return MolForgeObservation( scenario_id=self._scenario.scenario_id, difficulty=self._scenario.difficulty, state_label=self._state.state_label, state_path=list(self._state_path), coordination_mode=self._scenario.coordination_mode, # type: ignore[arg-type] enabled_roles=list(self._scenario.enabled_roles), task_brief=self._scenario.task_brief, target_name=self._scenario.target_name, current_molecule=current_signature, molecule_slots=[ MoleculeSlot(slot=slot, fragment=self._molecule[slot], editable=True) for slot in SLOT_ORDER ], editable_slots=list(EDITABLE_SLOTS), step_index=self._state.step_count, max_steps=self._scenario.max_steps, remaining_budget=self._state.remaining_budget, budget_used=self._state.budget_used, max_budget=self._scenario.oracle_budget, known_assays=deepcopy(self._known_assays), role_observations=self._build_role_observations(current_signature), message_log=[message.model_dump() for message in self._message_log[-8:]], governance=deepcopy(self._last_governance), last_transition_summary=self._last_summary, visible_metrics=visible_metrics, constraint_status=constraint_status, reward_breakdown=reward_components, allowed_actions=[ "Lead Chemist: edit, submit, restart, defer", "Assay Planner: run_assay", "Messages: proposal, approval, objection, risk_flag, assay_request, rejection", ], report_card=self._report_card, metadata=metadata, done=done, reward=reward, ) def _build_visible_constraints(self, molecule_signature: str) -> List[ConstraintCheck]: checks: List[ConstraintCheck] = [] for name, threshold in self._scenario.hard_constraints.items(): property_name = "toxicity" if name == "toxicity_max" else name.split("_")[0] estimate = self._current_property_estimate(property_name, molecule_signature) relation = "<=" if name.endswith("_max") else ">=" if estimate is None: checks.append( ConstraintCheck( name=name, target=f"{relation} {threshold:.2f}", satisfied=None, actual=None, evidence_status="unknown", ) ) continue satisfied = estimate <= threshold if name.endswith("_max") else estimate >= threshold checks.append( ConstraintCheck( name=name, target=f"{relation} {threshold:.2f}", satisfied=satisfied, actual=round(estimate, 4), evidence_status="known", ) ) return checks def _build_role_observations(self, molecule_signature: str) -> List[RoleObservation]: current_assays = [ reading.model_dump() for reading in self._known_assays if reading.molecule_signature == molecule_signature ] evidence_gaps = [ prop for prop in ["potency", "toxicity", "synth"] if self._current_property_estimate(prop, molecule_signature) is None ] edit_history = [ entry["action"] for entry in self._history if entry["action"].get("action_type") == "edit" ][-4:] return [ RoleObservation( role="lead_chemist", local_objective="Propose high-value scaffold edits and decide when the team should submit.", permissions=ROLE_PERMISSIONS["lead_chemist"], observation={ "molecule_slots": deepcopy(self._molecule), "edit_history": edit_history, "visible_assays": current_assays, "candidate_edits": [ {"slot": slot, "fragment": fragment} for slot, fragment in list(enumerate_candidate_edits(self._molecule))[:8] ], "open_questions": evidence_gaps, }, ), RoleObservation( role="toxicologist", local_objective="Protect against safety regressions and unsafe submissions.", permissions=ROLE_MESSAGE_TYPES["toxicologist"], observation={ "toxicity_readouts": [ reading for reading in current_assays if reading["property_name"] == "toxicity" ], "hard_threshold": self._scenario.hard_constraints.get("toxicity_max"), "safety_alerts": self._safety_alerts(), "risk_history": [ message.model_dump() for message in self._message_log if message.sender == "toxicologist" ][-4:], }, ), RoleObservation( role="assay_planner", local_objective="Allocate assay budget where the expected information gain is highest.", permissions=ROLE_PERMISSIONS["assay_planner"] + ROLE_MESSAGE_TYPES["assay_planner"], observation={ "budget_ledger": { "remaining_budget": self._state.remaining_budget, "budget_used": self._state.budget_used, "max_budget": self._state.max_budget, }, "tool_costs": deepcopy(DEFAULT_TOOL_COSTS), "tool_usage_history": deepcopy(self._assay_runs), "evidence_gaps": evidence_gaps, "estimated_information_value": { tool_name: round(self._estimate_information_gain(tool_name), 4) for tool_name in self._scenario.enabled_tools }, }, ), RoleObservation( role="process_chemist", local_objective="Guard tractability and synthetic feasibility before the team commits.", permissions=ROLE_MESSAGE_TYPES["process_chemist"], observation={ "synth_readouts": [ reading for reading in current_assays if reading["property_name"] == "synth" ], "route_warnings": self._route_warnings(), "feasibility_flags": { "heavy_hinge": self._molecule["hinge"] == "quinazoline", "reactive_warhead": self._molecule["warhead"] == "vinyl_sulfonamide", "lipophilic_tail": self._molecule["back_pocket"] == "trifluoromethyl", }, }, ), ] def _grade_all(self) -> Dict[str, float]: properties = self._true_properties() constraints = evaluate_constraints(properties, self._scenario) constraint_margins = evaluate_constraint_margins(properties, self._scenario) constraint_margin_score = sum(constraint_margins.values()) / max(len(constraint_margins), 1) constraint_fraction = sum(1.0 for passed, _ in constraints.values() if passed) / max(len(constraints), 1) submitted = self._state.submitted coordination_score = self._coordination_score() evidence_score = self._evidence_score() budget_score = self._open_unit_interval( self._state.remaining_budget / max(self._scenario.oracle_budget, 1), ) progress_score = self._grade_progress( candidate_score=compute_objective_score(properties, self._scenario), constraint_margin_score=constraint_margin_score, constraint_fraction=constraint_fraction, evidence_score=evidence_score, coordination_score=coordination_score, budget_score=budget_score, ) submission_score = self._grade_submission(properties) if submitted else 0.0 final_score = self._grade_final( submission_score=submission_score, progress_score=progress_score, submitted=submitted, constraint_fraction=constraint_fraction, evidence_score=evidence_score, ) return { "final_score": final_score, "potency_score": self._open_unit_interval(properties["potency"]), "safety_score": self._open_unit_interval(1.0 - properties["toxicity"]), "synth_score": self._open_unit_interval(properties["synth"]), "novelty_score": self._open_unit_interval(properties["novelty"]), "candidate_score": self._open_unit_interval(compute_objective_score(properties, self._scenario)), "constraint_score": self._open_unit_interval( sum(1.0 for passed, _ in constraints.values() if passed) / max(len(constraints), 1), ), "constraint_margin_score": self._open_unit_interval(constraint_margin_score), "budget_score": budget_score, "submitted_score": 1.0 if submitted else 0.0, "submission_score": submission_score, "progress_score": progress_score, "coordination_score": self._open_unit_interval(coordination_score), "evidence_score": self._open_unit_interval(evidence_score), } def _grade_progress( self, *, candidate_score: float, constraint_margin_score: float, constraint_fraction: float, evidence_score: float, coordination_score: float, budget_score: float, ) -> float: """Score scientific progress even when no formal submission happened.""" progress = ( 0.45 * candidate_score + 0.35 * constraint_margin_score + 0.10 * evidence_score + 0.05 * coordination_score + 0.05 * budget_score ) repeated_assays = sum(max(0, runs - 1) for runs in self._assay_runs.values()) policy_vetoes = sum( 1 for entry in self._history if entry.get("governance", {}).get("status") == "policy_veto" ) progress -= min(0.20, 0.04 * repeated_assays) progress -= min(0.20, 0.05 * policy_vetoes) if constraint_fraction < 1.0: progress = min(progress, 0.25 + 0.25 * constraint_fraction) if not self._state.submitted and evidence_score < 0.99: progress = min(progress, 0.45) if self._scenario.trap_penalty and not self._restart_used: progress = min(progress, 0.30) if self._state.submitted: progress += 0.05 return self._open_unit_interval(progress) def _grade_final( self, *, submission_score: float, progress_score: float, submitted: bool, constraint_fraction: float, evidence_score: float, ) -> float: """Single conservative scalar for RL/evaluation headline reporting.""" if submitted: return self._open_unit_interval(submission_score) score = 0.35 * progress_score if constraint_fraction < 1.0: score = min(score, 0.05 + 0.10 * constraint_fraction) if evidence_score < 0.99: score = min(score, 0.15) if self._scenario.trap_penalty and not self._restart_used: score = min(score, 0.08) return self._open_unit_interval(score) def _coordination_score(self) -> float: expected_messages = 0 for entry in self._history: action = entry.get("action", {}) if action.get("action_type") == "defer": continue expected_messages += 1 + len(entry.get("governance", {}).get("required_roles", [])) if expected_messages == 0: return self._open_unit_interval(0.0) total_correct = sum(metrics["correct_messages"] for metrics in self._role_metrics.values()) return self._open_unit_interval(min(total_correct, expected_messages) / expected_messages) def _grade_submission(self, properties: Mapping[str, float]) -> float: base = compute_objective_score(properties, self._scenario) constraint_margins = evaluate_constraint_margins(properties, self._scenario) constraint_margin_score = sum(constraint_margins.values()) / max(len(constraint_margins), 1) constraints = evaluate_constraints(properties, self._scenario) constraint_fraction = sum(1.0 for passed, _ in constraints.values() if passed) / max(len(constraints), 1) submission_score = ( 0.60 * base + 0.20 * constraint_margin_score + 0.10 * self._coordination_score() + 0.10 * self._evidence_score() ) evidence_score = self._evidence_score() if evidence_score >= 0.99 and constraint_fraction >= 1.0 and base >= self._scenario.baseline_to_beat: budget_efficiency = self._state.remaining_budget / max(self._scenario.oracle_budget, 1) submission_score += 0.05 * max(0.0, budget_efficiency) if evidence_score < 1.0: submission_score = min(submission_score, 0.25 + 0.25 * evidence_score) if constraint_fraction < 1.0: submission_score = min(submission_score, 0.20 + 0.50 * constraint_margin_score) if base < self._scenario.baseline_to_beat: submission_score = min(submission_score, 0.45) return self._open_unit_interval(submission_score) def _evidence_score(self) -> float: current_signature = self._molecule_signature() required = ["potency", "toxicity"] if "synth_min" in self._scenario.hard_constraints: required.append("synth") available = sum( 1 for prop in required if self._current_property_estimate(prop, current_signature) is not None ) score = available / max(len(required), 1) if self._scenario.target_shift_step and self._target_shift_active(): has_post_shift_potency = any( entry["step"] >= self._scenario.target_shift_step and entry["molecule"] == current_signature and any(result["property_name"] == "potency" for result in entry["results"]) for entry in self._oracle_log ) score = min(score, 1.0 if has_post_shift_potency else 0.5) return score def _build_report_card(self, *, submitted: bool) -> str: properties = self._true_properties() grader_scores = self._grade_all() constraints = evaluate_constraints(properties, self._scenario) lines = [ f"Scenario: {self._scenario.scenario_id} ({self._scenario.difficulty})", f"Final molecule: {self._molecule_signature()}", f"Potency: {properties['potency']:.3f}", f"Toxicity: {properties['toxicity']:.3f}", f"Synthesizability: {properties['synth']:.3f}", f"Novelty: {properties['novelty']:.3f}", f"Final score: {grader_scores['final_score']:.3f}", f"Candidate scientific score: {grader_scores['candidate_score']:.3f}", f"Constraint margin score: {grader_scores['constraint_margin_score']:.3f}", f"Submission grader: {grader_scores['submission_score']:.3f}", f"Progress score: {grader_scores['progress_score']:.3f}", f"Coordination score: {grader_scores['coordination_score']:.3f}", f"Evidence score: {grader_scores['evidence_score']:.3f}", "Constraints:", ] for name, (passed, threshold) in constraints.items(): metric_name = "toxicity" if name == "toxicity_max" else name.split("_")[0] lines.append( f"- {name}: {'pass' if passed else 'fail'} (actual={properties[metric_name]:.3f}, threshold={threshold:.3f})" ) lines.append( f"Messages sent: {self._state.message_count}, objections raised: {self._state.objection_count}, oracle calls: {self._state.oracle_call_count}" ) if self._scenario.target_shift_step and self._target_shift_active(): lines.append("Target mutation triggered during this episode.") if self._restart_used: lines.append("Agent used restart_from_new_scaffold to escape the original trap series.") if not submitted: lines.append("Episode terminated without a formal submit action.") return "\n".join(lines)