| """Standalone pre/post-train evaluator for the drug-target-validation env. | |
| Loads a Hugging-Face causal-LM (or local checkpoint), runs ``--episodes`` | |
| fresh ``DrugTargetEnvironment`` rollouts using the same prompt-build / | |
| parser as ``training/training_script.py``, and writes one JSONL row per | |
| episode to ``--out``. | |
| Each row carries: | |
| * deterministic seed + scenario / target context, | |
| * the action sequence + per-step rewards the model produced, | |
| * whether the episode terminated with a submitted report, | |
| * whether the submitted go/no_go matches the hidden ``correct_decision``, | |
| * per-component reward totals derived from | |
| ``RewardBreakdown.to_dict()`` keys, summed over the trajectory. | |
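
An illustrative row, truncated (all values below are made up; the shape
matches ``EpisodeResult.to_jsonl``)::

    {"episode": 0, "scenario": "example_scenario", "target_gene": "EGFR",
     "indication": "NSCLC", "tag": "pre_train", "seed": 10000, "n_steps": 14,
     "cumulative_reward": 3.25, "submitted": true, "submitted_decision": "go",
     "correct_decision": "go", "decision_accuracy": 1.0,
     "evidence_coverage": 0.8, "credits_used": 42, "credits_total": 60, ...}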

The evaluator is intentionally model-agnostic: any HF-compatible
causal-LM works, including the SFT warm-start checkpoint that lives at
``runs/sft-warmstart`` and the GRPO output at ``runs/grpo-output``. A
final summary line is printed for the orchestration log.

CLI::

    python -m training.evaluate \\
        --model_name <hub-id-or-local-path> \\
        --episodes 8 --max_steps 25 \\
        --difficulty mixed \\
        --tag pre_train \\
        --out evidence/pre_eval.jsonl \\
        --record_components
"""

from __future__ import annotations

import argparse
import json
import logging
import statistics
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

from models import ActionType, DrugTargetAction, ValidationObservation
from server.hackathon_environment import DrugTargetEnvironment, MAX_STEPS
from training.training_script import (
    build_training_prompt,
    ensure_terminal_payload,
    parse_action_completion,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("training.evaluate")

# Reward component column names mirror ``RewardBreakdown.to_dict()`` so
# the JSONL is directly pivotable against the dashboard cards.
_REWARD_COMPONENTS = (
    "evidence_coverage",
    "decision_accuracy",
    "credit_efficiency",
    "reasoning_coherence",
    "novelty",
    "penalty",
    "shaping",
    "terminal",
)

# A maximally defensive fallback for when the model emits unparseable text.
# Mirrors ``training/training_script.OpenEnvReward.invalid_action_penalty``
# so the model is consistently penalised for non-JSON output.
_INVALID_ACTION_PENALTY = -2.0


@dataclass
class EpisodeResult:
    """Single-episode evaluation record (becomes one JSONL row)."""

    episode: int
    scenario: Optional[str]
    target_gene: str
    indication: str
    tag: str
    seed: int
    n_steps: int
    cumulative_reward: float
    submitted: bool
    submitted_decision: Optional[str]
    correct_decision: Optional[str]
    decision_accuracy: float
    evidence_coverage: float
    credits_used: int
    credits_total: int
    action_sequence: List[str]
    step_rewards: List[float]
    invalid_actions: int
    reward_components_total: Dict[str, float] = field(default_factory=dict)

    def to_jsonl(self) -> str:
        return json.dumps({
            "episode": self.episode,
            "scenario": self.scenario,
            "target_gene": self.target_gene,
            "indication": self.indication,
            "tag": self.tag,
            "seed": self.seed,
            "n_steps": self.n_steps,
            "cumulative_reward": round(self.cumulative_reward, 6),
            "submitted": self.submitted,
            "submitted_decision": self.submitted_decision,
            "correct_decision": self.correct_decision,
            "decision_accuracy": round(self.decision_accuracy, 6),
            "evidence_coverage": round(self.evidence_coverage, 6),
            "credits_used": self.credits_used,
            "credits_total": self.credits_total,
            "action_sequence": self.action_sequence,
            "step_rewards": [round(float(r), 6) for r in self.step_rewards],
            "invalid_actions": self.invalid_actions,
            "reward_components_total": {
                k: round(float(v), 6)
                for k, v in self.reward_components_total.items()
            },
        }, ensure_ascii=False)


def _parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # RawDescriptionHelpFormatter keeps the module docstring's CLI block
    # intact in --help output instead of reflowing it.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model_name", required=True,
                        help="HF hub id or local path (must be HF-compatible).")
    parser.add_argument("--episodes", type=int, default=8)
    parser.add_argument("--max_steps", type=int, default=25)
    parser.add_argument(
        "--difficulty",
        choices=["easy", "medium", "hard", "mixed"],
        default="mixed",
        help="Filter the scenario library to this difficulty tier.",
    )
    parser.add_argument(
        "--scenario_name",
        default=None,
        help="Optional scenario name override; by default samples from the full library.",
    )
    parser.add_argument("--tag", required=True,
                        help="Free-form tag (e.g. pre_train, post_train); copied into each row.")
    parser.add_argument("--out", required=True,
                        help="Output JSONL path (the parent directory is created).")
    parser.add_argument("--seed_base", type=int, default=10_000)
    parser.add_argument("--max_new_tokens", type=int, default=384)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument(
        "--record_components",
        action="store_true",
        default=True,
        help=(
            "Record per-episode reward component totals "
            "(evidence_coverage, decision_accuracy, ...). On by default."
        ),
    )
    parser.add_argument(
        "--no_record_components",
        dest="record_components",
        action="store_false",
        help="Skip per-component recording (writes an empty dict).",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Pass trust_remote_code=True to model/tokenizer loading.",
    )
    parser.add_argument(
        "--no_sample",
        action="store_true",
        help="Use greedy decoding (do_sample=False).",
    )
    return parser.parse_args(argv)


def _filter_scenarios(env: DrugTargetEnvironment, difficulty: str) -> None:
    """Restrict the env's task generator to scenarios of the given tier."""
    if difficulty == "mixed":
        return
    env._task_gen.scenarios = [
        s for s in env._task_gen.scenarios if s.difficulty == difficulty
    ]
    if not env._task_gen.scenarios:
        raise SystemExit(
            f"No scenarios in the library match difficulty={difficulty!r} "
            "(tiers: easy, medium, hard)."
        )


def _load_model(model_name: str, *, trust_remote_code: bool):
    """Best-effort HF causal-LM load with a sensible dtype + device pick."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    use_cuda = torch.cuda.is_available()
    bf16 = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)()) if use_cuda else False
    dtype = torch.bfloat16 if bf16 else (torch.float16 if use_cuda else torch.float32)

    logger.info("loading tokenizer for %s", model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        # Many causal-LMs ship without a pad token; reuse EOS so generation
        # does not fail when a pad_token_id is required.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("loading model for %s (cuda=%s, dtype=%s)", model_name, use_cuda, dtype)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        torch_dtype=dtype,
    )
    if use_cuda:
        model = model.to("cuda")
    model.eval()
    return tokenizer, model


def _generate_action_text(
    *,
    tokenizer,
    model,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    do_sample: bool,
) -> str:
    """Run ``model.generate`` once and return only the new completion text."""
    import torch

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = int(inputs["input_ids"].shape[1])
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": bool(do_sample),
        "pad_token_id": tokenizer.pad_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
    with torch.no_grad():
        out_ids = model.generate(**inputs, **gen_kwargs)
    new_tokens = out_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _accumulate_components(
    totals: Dict[str, float],
    breakdown: Dict[str, float],
) -> None:
    """Sum step / terminal reward-breakdown values into the running totals.

    ``DrugTargetEnvironment`` already merges step + terminal breakdowns into
    a single dict per step (terminal keys carry a ``term_`` prefix). We
    therefore look up both shapes for each canonical component and add
    them together so the totals are *per-episode*, not per-step.
    """
    for comp in _REWARD_COMPONENTS:
        for key in (comp, f"term_{comp}"):
            v = breakdown.get(key)
            if isinstance(v, (int, float)):
                totals[comp] = totals.get(comp, 0.0) + float(v)
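
# Illustrative behaviour of ``_accumulate_components`` (numbers made up):
# a merged per-step breakdown such as
#     {"shaping": 0.1, "term_decision_accuracy": 1.0}
# adds 0.1 to totals["shaping"] and 1.0 to totals["decision_accuracy"],
# folding the ``term_`` prefix back onto the canonical component key.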


def _final_episode_metrics(
    env: DrugTargetEnvironment,
    last_breakdown: Dict[str, float],
) -> Dict[str, Any]:
    """Pull terminal evidence_coverage + decision_accuracy + ground truth.

    ``env._latent`` is rebound on every step, so we read it *after* the
    final ``env.step`` call; the post-terminal breakdown carries the
    ``term_*`` keys that the reward computer wrote out.
    """
    latent = env._latent
    correct_decision: Optional[str] = None
    submitted = False
    if latent is not None:
        correct_decision = latent.target.correct_decision
        submitted = bool(latent.progress.report_submitted)

    def _lookup(component: str) -> float:
        # Prefer the ``term_``-prefixed key; explicit type checks keep a
        # legitimate 0.0 from being skipped the way an ``or`` chain would.
        for key in (f"term_{component}", component):
            v = last_breakdown.get(key)
            if isinstance(v, (int, float)):
                return float(v)
        return 0.0

    return {
        "correct_decision": correct_decision,
        "submitted": submitted,
        "evidence_coverage": _lookup("evidence_coverage"),
        "decision_accuracy": _lookup("decision_accuracy"),
    }


def evaluate_episode(
    *,
    episode_idx: int,
    seed: int,
    args: argparse.Namespace,
    tokenizer,
    model,
) -> EpisodeResult:
    """Run a single evaluation episode end-to-end."""
    env = DrugTargetEnvironment(
        scenario_name=args.scenario_name,
        domain_randomise=False,
    )
    _filter_scenarios(env, args.difficulty)
    obs: ValidationObservation = env.reset(seed=seed)
    target_gene = obs.target_gene
    indication = obs.indication
    scenario_name = (
        args.scenario_name
        or getattr(env._task, "scenario_name", None)
        or getattr(env._task, "target_gene", None)
    )

    action_sequence: List[str] = []
    step_rewards: List[float] = []
    cumulative = 0.0
    invalid = 0
    submitted_decision: Optional[str] = None
    components_total: Dict[str, float] = {c: 0.0 for c in _REWARD_COMPONENTS}
    last_breakdown: Dict[str, float] = {}
    n_steps = 0

    cap = min(int(args.max_steps), MAX_STEPS)
    for step_idx in range(cap):
        if obs.done:
            break

        prompt = build_training_prompt(obs)
        completion = _generate_action_text(
            tokenizer=tokenizer,
            model=model,
            prompt=prompt,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            do_sample=not args.no_sample,
        )
        action: Optional[DrugTargetAction] = parse_action_completion(completion)
        if action is None:
            invalid += 1
            cumulative += _INVALID_ACTION_PENALTY
            step_rewards.append(_INVALID_ACTION_PENALTY)
            action_sequence.append("__invalid__")
            n_steps += 1
            # We deliberately do *not* call env.step on garbage output; the
            # env would raise. Instead we charge the standard invalid-action
            # penalty (mirroring the GRPO reward function) and try again on
            # the next step. This means n_steps reflects all generation
            # attempts, not just env ticks.
            continue

        action = ensure_terminal_payload(obs, action)
        try:
            obs = env.step(action)
        except Exception as exc:
            logger.warning(
                "episode %d step %d: env.step raised (%s); episode terminated",
                episode_idx, step_idx, exc,
            )
            cumulative += _INVALID_ACTION_PENALTY
            step_rewards.append(_INVALID_ACTION_PENALTY)
            action_sequence.append(action.action_type.value)
            n_steps += 1
            invalid += 1
            break

        action_sequence.append(action.action_type.value)
        step_reward = float(obs.reward or 0.0)
        cumulative += step_reward
        step_rewards.append(step_reward)
        last_breakdown = dict(obs.step_reward_breakdown or {})
        if args.record_components:
            _accumulate_components(components_total, last_breakdown)
        if action.action_type == ActionType.SUBMIT_VALIDATION_REPORT:
            submitted_decision = action.final_decision
        n_steps += 1

    final = _final_episode_metrics(env, last_breakdown)
    correct_decision = final["correct_decision"]
    submitted = final["submitted"]
    # Even if the env rejected the submission (``submitted`` is False),
    # ``submitted_decision`` still carries whatever the model last picked —
    # it was captured when the SUBMIT action was generated — so the JSONL
    # row stays inspectable without any extra backstop here.

    return EpisodeResult(
        episode=episode_idx,
        scenario=scenario_name,
        target_gene=target_gene,
        indication=indication,
        tag=args.tag,
        seed=seed,
        n_steps=n_steps,
        cumulative_reward=cumulative,
        submitted=submitted,
        submitted_decision=submitted_decision,
        correct_decision=correct_decision,
        decision_accuracy=float(final["decision_accuracy"]),
        evidence_coverage=float(final["evidence_coverage"]),
        credits_used=int(getattr(obs, "credits_total", 0) - getattr(obs, "credits_remaining", 0)),
        credits_total=int(getattr(obs, "credits_total", 0)),
        action_sequence=action_sequence,
        step_rewards=step_rewards,
        invalid_actions=invalid,
        reward_components_total=components_total if args.record_components else {},
    )


def _summary(results: List[EpisodeResult]) -> Dict[str, float]:
    """Aggregate stats for the human-readable log line."""
    n = len(results)
    if n == 0:
        return {"n": 0}
    rewards = [r.cumulative_reward for r in results]
    submitted = [1.0 if r.submitted else 0.0 for r in results]
    decision_match = [
        1.0 if (r.submitted and r.submitted_decision == r.correct_decision) else 0.0
        for r in results
    ]
    success = [
        1.0 if (
            r.submitted
            and r.submitted_decision == r.correct_decision
            and r.cumulative_reward > 0
        ) else 0.0
        for r in results
    ]
    return {
        "n": n,
        "mean_reward": sum(rewards) / n,
        "median_reward": statistics.median(rewards),
        "submission_rate": sum(submitted) / n,
        "decision_match_rate": sum(decision_match) / n,
        "success_rate": sum(success) / n,
        "mean_evidence_coverage": sum(r.evidence_coverage for r in results) / n,
        "mean_decision_accuracy": sum(r.decision_accuracy for r in results) / n,
    }
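
# Downstream use of the JSONL (a sketch, not executed here; assumes pandas
# is installed and that a matching post-train file exists):
#
#     import pandas as pd
#     pre = pd.read_json("evidence/pre_eval.jsonl", lines=True)
#     post = pd.read_json("evidence/post_eval.jsonl", lines=True)
#     print(post["cumulative_reward"].mean() - pre["cumulative_reward"].mean())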


def main(argv: Optional[List[str]] = None) -> int:
    args = _parse_args(argv)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    t0 = time.time()
    tokenizer, model = _load_model(
        args.model_name,
        trust_remote_code=args.trust_remote_code,
    )
    load_s = time.time() - t0
    logger.info(
        "starting eval: model=%s episodes=%d max_steps=%d tag=%s difficulty=%s",
        args.model_name, args.episodes, args.max_steps, args.tag, args.difficulty,
    )

    results: List[EpisodeResult] = []
    eval_t0 = time.time()
    with open(out_path, "w", encoding="utf-8") as f:
        for i in range(int(args.episodes)):
            seed = int(args.seed_base) + i
            try:
                result = evaluate_episode(
                    episode_idx=i,
                    seed=seed,
                    args=args,
                    tokenizer=tokenizer,
                    model=model,
                )
            except Exception as exc:
                logger.exception("episode %d failed: %s", i, exc)
                # Write a sentinel row so episode indices stay contiguous in
                # the JSONL even when an episode crashes outright.
                result = EpisodeResult(
                    episode=i,
                    scenario=args.scenario_name,
                    target_gene="UNKNOWN",
                    indication="unspecified",
                    tag=args.tag,
                    seed=seed,
                    n_steps=0,
                    cumulative_reward=_INVALID_ACTION_PENALTY,
                    submitted=False,
                    submitted_decision=None,
                    correct_decision=None,
                    decision_accuracy=0.0,
                    evidence_coverage=0.0,
                    credits_used=0,
                    credits_total=0,
                    action_sequence=[],
                    step_rewards=[],
                    invalid_actions=1,
                    reward_components_total={c: 0.0 for c in _REWARD_COMPONENTS},
                )
            f.write(result.to_jsonl() + "\n")
            f.flush()
            results.append(result)
            logger.info(
                "[%s ep=%d] reward=%+.3f steps=%d submitted=%s decision=%s/%s "
                "coverage=%.2f decision_acc=%.2f",
                args.tag, i, result.cumulative_reward, result.n_steps,
                result.submitted, result.submitted_decision,
                result.correct_decision, result.evidence_coverage,
                result.decision_accuracy,
            )
    eval_s = time.time() - eval_t0

    summary = _summary(results)
    summary.update({
        "load_duration_s": round(load_s, 2),
        "eval_duration_s": round(eval_s, 2),
        "model_name": args.model_name,
        "tag": args.tag,
        "out": str(out_path),
    })
    logger.info("[eval-summary tag=%s] %s", args.tag, json.dumps(summary, indent=2))
    return 0
| if __name__ == "__main__": | |
| sys.exit(main()) | |