# NOTE: stray Hugging Face Spaces page banner text ("Spaces:", "Runtime error")
# was captured above the module docstring; it is not valid Python and has been
# replaced by this comment.
| """Run the drug-target-validation environment with Qwen as the agent. | |
| This script provides the baseline inference loop used by the dashboard. | |
| It instantiates :class:`DrugTargetEnvironment`, formats each | |
| :class:`ValidationObservation` into a prompt, calls the local Qwen model | |
| to obtain a structured :class:`DrugTargetAction`, and writes the running | |
| state to the dashboard JSON file. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| from models import ( | |
| ActionType, | |
| DrugTargetAction, | |
| ValidationObservation, | |
| build_agent_observation_context, | |
| build_agent_system_prompt, | |
| ) | |
| from server.hackathon_environment import DrugTargetEnvironment | |
# On-disk JSON files used to exchange state and commands with the dashboard UI.
DASHBOARD_STATE_PATH = Path(__file__).parent / "_dashboard_state.json"
DASHBOARD_CMD_PATH = Path(__file__).parent / "_dashboard_cmd.json"
# Opt-in switch for the transformers pipeline() path instead of raw
# tokenizer+model generation. NOTE(review): an *empty* env value enables the
# pipeline ("" is not in the falsy set) — confirm that is intended.
USE_PIPELINE = os.getenv("RUN_AGENT_USE_PIPELINE", "0").strip().lower() not in {
    "0", "false", "off",
}
| def _parse_thinking_flag() -> bool: | |
| import sys | |
| if "--no-thinking" in sys.argv: | |
| return False | |
| if "--thinking" in sys.argv: | |
| return True | |
| return os.getenv("RUN_AGENT_ENABLE_THINKING", "1").strip().lower() not in { | |
| "0", "false", "off", | |
| } | |
# Whether the model should emit <think>...</think> reasoning (CLI/env driven).
ENABLE_THINKING = _parse_thinking_flag()
# Hugging Face model identifier loaded locally by main().
MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
# Hard cap on agent steps per episode; a terminal report is forced at the cap.
MAX_EPISODE_STEPS = int(os.getenv("RUN_AGENT_MAX_EPISODE_STEPS", "20"))
PIPELINE_TASK = "text-generation"
# Canonical action names accepted from the model.
ACTION_TYPES = [a.value for a in ActionType]
# Loose synonyms a small model tends to emit, mapped to canonical actions.
ACTION_TYPE_ALIASES: Dict[str, str] = {
    # Expression / omics
    "expression_lookup": ActionType.QUERY_EXPRESSION.value,
    "tissue_expression": ActionType.QUERY_EXPRESSION.value,
    "gtex_query": ActionType.QUERY_EXPRESSION.value,
    "de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "differential_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "pathway": ActionType.PATHWAY_ENRICHMENT.value,
    "coexpression": ActionType.COEXPRESSION_NETWORK.value,
    # Protein
    "structure_lookup": ActionType.PROTEIN_STRUCTURE_LOOKUP.value,
    "alphafold_lookup": ActionType.PROTEIN_STRUCTURE_LOOKUP.value,
    "binding_site": ActionType.BINDING_SITE_ANALYSIS.value,
    "ppi": ActionType.PROTEIN_INTERACTION_NETWORK.value,
    "interaction_network": ActionType.PROTEIN_INTERACTION_NETWORK.value,
    "druggability": ActionType.DRUGGABILITY_SCREEN.value,
    # Clinical
    "trial_lookup": ActionType.CLINICAL_TRIAL_LOOKUP.value,
    "clinical_lookup": ActionType.CLINICAL_TRIAL_LOOKUP.value,
    "toxicity": ActionType.TOXICITY_PANEL.value,
    "tox_panel": ActionType.TOXICITY_PANEL.value,
    "off_target": ActionType.OFF_TARGET_SCREEN.value,
    "selectivity": ActionType.OFF_TARGET_SCREEN.value,
    "stratification": ActionType.PATIENT_STRATIFICATION.value,
    "patient_subset": ActionType.PATIENT_STRATIFICATION.value,
    # Literature
    "literature": ActionType.LITERATURE_SEARCH.value,
    "pubmed": ActionType.LITERATURE_SEARCH.value,
    "synthesis": ActionType.EVIDENCE_SYNTHESIS.value,
    "competitor": ActionType.COMPETITOR_LANDSCAPE.value,
    # Experimental
    "in_vitro": ActionType.IN_VITRO_ASSAY.value,
    "cell_assay": ActionType.IN_VITRO_ASSAY.value,
    "in_vivo": ActionType.IN_VIVO_MODEL.value,
    "mouse_model": ActionType.IN_VIVO_MODEL.value,
    "crispr": ActionType.CRISPR_KNOCKOUT.value,
    "biomarker": ActionType.BIOMARKER_CORRELATION.value,
    # Meta
    "red_flag": ActionType.FLAG_RED_FLAG.value,
    "expert_review": ActionType.REQUEST_EXPERT_REVIEW.value,
    "submit_report": ActionType.SUBMIT_VALIDATION_REPORT.value,
    "final_report": ActionType.SUBMIT_VALIDATION_REPORT.value,
    "submit": ActionType.SUBMIT_VALIDATION_REPORT.value,
}
SYSTEM_PROMPT = build_agent_system_prompt()
# Sensible default investigation order used as a fallback / hint for the
# dashboard prompt rather than a hard pipeline.
STANDARD_PIPELINE_ORDER: List[ActionType] = [
    ActionType.QUERY_EXPRESSION,
    ActionType.DRUGGABILITY_SCREEN,
    ActionType.OFF_TARGET_SCREEN,
    ActionType.TOXICITY_PANEL,
    ActionType.CLINICAL_TRIAL_LOOKUP,
    ActionType.LITERATURE_SEARCH,
    ActionType.SUBMIT_VALIDATION_REPORT,
]
# NOTE(review): not referenced anywhere in this file — write_dashboard_state
# truncates the raw model response with a hard-coded 600 instead; presumably
# these were meant to agree. TODO confirm.
MODEL_RESPONSE_PREVIEW_CHARS = int(
    os.getenv("RUN_AGENT_MODEL_RESPONSE_PREVIEW_CHARS", "240")
)
def compact_preview(value: Any, max_chars: int = 160) -> str:
    """Render *value* as a single-line JSON-ish string capped at *max_chars*.

    Falls back to ``str(value)`` when the value is not JSON-serializable
    (``TypeError``) or contains a circular reference (``ValueError`` —
    json.dumps raises ValueError, not TypeError, for cycles; the original
    only caught TypeError and would crash on cyclic data). Whitespace runs
    are collapsed so the preview stays on one line; long previews are
    truncated with a ``...`` suffix.
    """
    try:
        text = json.dumps(value, ensure_ascii=True, sort_keys=True)
    except (TypeError, ValueError):
        # ValueError covers "Circular reference detected".
        text = str(value)
    text = re.sub(r"\s+", " ", text).strip()
    if len(text) <= max_chars:
        return text
    return text[: max_chars - 3] + "..."
def format_observation(obs: ValidationObservation) -> str:
    """Render *obs* as the user-turn prompt text for the agent model.

    The prompt contains, in order: the task header, the shared tool context,
    the last five history entries, completed/suggested action hints, compact
    dossier summaries, the latest tool output, any rule violations, and the
    strict JSON-only output instructions.
    """
    parts = [
        f"TARGET: {obs.target_gene} | INDICATION: {obs.indication}",
        f"DISEASE CONTEXT: {obs.disease_context}",
        f"Step: {obs.step_index} | Credits: "
        f"{obs.credits_remaining}/{obs.credits_total}",
    ]
    # Shared context (available tools etc.) built by the models helpers.
    context = build_agent_observation_context(obs, max_tools=5)
    if context:
        parts.append(context)
    if obs.pipeline_history:
        # Only the five most recent steps, truncated, to keep the prompt short.
        last5 = obs.pipeline_history[-5:]
        parts.append("Recent history:")
        for h in last5:
            tag = "OK" if h.get("success") else "FAIL"
            line = (
                f" [{tag}] {h.get('action_type')}: "
                f"{str(h.get('output_summary', ''))[:80]}"
            )
            parts.append(line)
    # Steer the model away from repeating successful actions and toward the
    # remaining steps of the standard pipeline order.
    completed = {
        h.get("action_type") for h in obs.pipeline_history if h.get("success")
    }
    if completed:
        parts.append(
            f"Completed actions (try not to repeat): "
            f"{', '.join(sorted(map(str, completed)))}"
        )
    remaining = [
        a.value for a in STANDARD_PIPELINE_ORDER
        if a.value not in completed
    ]
    if remaining:
        parts.append(f"Suggested next actions (pick one): {', '.join(remaining)}")
    # One compact line per dossier evidence category, when present.
    dossier = obs.dossier
    if dossier.flagged_red_flags:
        parts.append(f"Red flags so far: {dossier.flagged_red_flags[:5]}")
    if dossier.expression_findings:
        parts.append(
            f"Expression findings: "
            f"{compact_preview(dossier.expression_findings, 200)}"
        )
    if dossier.protein_findings:
        parts.append(
            f"Protein findings: "
            f"{compact_preview(dossier.protein_findings, 200)}"
        )
    if dossier.safety_findings:
        parts.append(
            f"Safety findings: "
            f"{compact_preview(dossier.safety_findings, 200)}"
        )
    if dossier.clinical_findings:
        parts.append(
            f"Clinical findings: "
            f"{compact_preview(dossier.clinical_findings, 200)}"
        )
    if dossier.experimental_results:
        parts.append(
            f"Experimental results: "
            f"{compact_preview(dossier.experimental_results, 200)}"
        )
    if obs.latest_output and obs.latest_output.data:
        parts.append(
            f"Latest output: {compact_preview(obs.latest_output.data, 200)}"
        )
    if obs.rule_violations:
        parts.append(f"VIOLATIONS: {obs.rule_violations}")
    # Strict output contract: parse_action expects exactly this JSON shape.
    parts.append(
        'Output ONLY a single JSON object with these exact keys, no '
        'comments, no extra text:\n'
        '{"action_type": "<one of the available actions>", '
        '"parameters": {}, "reasoning": "<why this step now>", '
        '"final_decision": null, "confidence": null}\n'
        'For the final submit_validation_report, set "final_decision" to '
        '"go" or "no_go" and "confidence" to a number in [0, 1].'
    )
    return "\n".join(parts)
| # ── JSON parsing helpers (kept compatible with prior small-LLM quirks) ─ | |
| def _repair_truncated_json(text: str) -> Optional[str]: | |
| s = text.strip() | |
| if not s.startswith("{"): | |
| return None | |
| s = re.sub(r',\s*"[^"\n]*$', '', s) | |
| s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s) | |
| in_string = False | |
| escape = False | |
| for ch in s: | |
| if escape: | |
| escape = False | |
| continue | |
| if ch == "\\": | |
| escape = True | |
| continue | |
| if ch == '"': | |
| in_string = not in_string | |
| if in_string: | |
| s += '"' | |
| open_braces = s.count("{") - s.count("}") | |
| open_brackets = s.count("[") - s.count("]") | |
| s += "]" * max(0, open_brackets) | |
| s += "}" * max(0, open_braces) | |
| try: | |
| obj = json.loads(s) | |
| if isinstance(obj, dict): | |
| return s | |
| except json.JSONDecodeError: | |
| pass | |
| s = re.sub(r',\s*([}\]])', r'\1', s) | |
| try: | |
| obj = json.loads(s) | |
| if isinstance(obj, dict): | |
| return s | |
| except json.JSONDecodeError: | |
| pass | |
| return None | |
| def _strip_js_comments(text: str) -> str: | |
| text = re.sub(r'//[^\n]*', '', text) | |
| text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL) | |
| return text | |
def _normalize_jsonish_text(text: str) -> str:
    """Massage JSON-ish model output toward strict JSON.

    Removes JS comments, rewrites Python literals (None/True/False) that
    appear in value position into their JSON spellings, and repairs keys
    where the colon slipped inside the quotes (``"key:",`` -> ``"key": "",``).
    """
    cleaned = _strip_js_comments(text)
    # Python -> JSON literal spellings, only directly after ": ".
    for py_literal, json_literal in (
        ("None", "null"),
        ("True", "true"),
        ("False", "false"),
    ):
        cleaned = re.sub(rf'(?<=:\s)\b{py_literal}\b', json_literal, cleaned)
    # '"key:",' -> '"key": "",' (colon trapped inside the quoted key).
    cleaned = re.sub(r'"([^"\n]+?):"\s*,', r'"\1": "",', cleaned)
    return cleaned
def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object (dict) from model output.

    Tolerates the common small-LLM failure modes: Python literals and JS
    comments (via _normalize_jsonish_text), the whole object serialized as
    one quoted JSON string, markdown code fences, surrounding prose, and
    truncated output (via _repair_truncated_json). Candidates are tried
    longest-first so the fullest object wins; returns the first candidate
    that parses to a dict, or None.
    """
    stripped = _normalize_jsonish_text(text).strip()
    # Unwrap a double-encoded payload: the object serialized as one string.
    if stripped.startswith('"') and stripped.endswith('"'):
        try:
            unwrapped = json.loads(stripped)
        except json.JSONDecodeError:
            unwrapped = None
        if isinstance(unwrapped, str):
            stripped = _normalize_jsonish_text(unwrapped).strip()
    # Strip a ``` fenced block (first and last lines are the fences).
    fence_prefix = "```"
    if stripped.startswith(fence_prefix) and stripped.endswith(fence_prefix):
        lines = stripped.splitlines()
        if len(lines) >= 3:
            stripped = "\n".join(lines[1:-1]).strip()
    # Collect every balanced {...} span (from each "{") as a parse candidate.
    candidates: List[str] = [stripped]
    start = stripped.find("{")
    while start != -1:
        depth = 0
        for idx in range(start, len(stripped)):
            char = stripped[idx]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    candidates.append(stripped[start:idx + 1])
                    break
        start = stripped.find("{", start + 1)
    # Also try repairing the tail from the first brace (truncated output).
    repaired = None
    first_brace = stripped.find("{")
    if first_brace != -1:
        repaired = _repair_truncated_json(stripped[first_brace:])
    if repaired is not None:
        candidates.append(repaired)
    # Longest candidate first: prefer the most complete object.
    candidates.sort(key=len, reverse=True)
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None
| def _edit_distance(a: str, b: str) -> int: | |
| if len(a) < len(b): | |
| return _edit_distance(b, a) | |
| if not b: | |
| return len(a) | |
| prev = list(range(len(b) + 1)) | |
| for i, ca in enumerate(a): | |
| curr = [i + 1] | |
| for j, cb in enumerate(b): | |
| curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (ca != cb))) | |
| prev = curr | |
| return prev[-1] | |
def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
    """Fetch the first of *names* present in *payload*, or None.

    Lookup order: exact key, case-insensitive key, then a fuzzy match that
    tolerates small typos (edit-distance threshold scaled to key length).
    """
    for wanted in names:
        if wanted in payload:
            return payload[wanted]
    folded = {str(key).lower(): value for key, value in payload.items()}
    for wanted in names:
        folded_name = wanted.lower()
        if folded_name in folded:
            return folded[folded_name]
    for key, value in folded.items():
        for wanted in names:
            tolerance = max(2, len(wanted) // 3)
            if _edit_distance(key, wanted.lower()) <= tolerance:
                return value
    return None
def normalize_optional_string(value: Any) -> Optional[str]:
    """Coerce *value* to a trimmed string, or None when absent/blank.

    Booleans are rejected explicitly (they subclass int and would otherwise
    stringify to "True"/"False"); numbers stringify; anything else gets a
    short compact preview.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, str):
        trimmed = value.strip()
        return trimmed if trimmed else None
    if isinstance(value, (int, float)):
        return str(value)
    return compact_preview(value, 80)
def normalize_action_type(raw: Any) -> Optional[str]:
    """Map a model-emitted action name to a canonical ActionType value.

    Resolution order: exact match -> alias table -> snake_case-sanitized
    match -> sanitized alias -> keyword heuristics. Returns None when the
    name cannot be recognized. Heuristic order matters: more specific
    fragment sets (e.g. "binding"+"site") must come before catch-alls like
    "report".
    """
    if not isinstance(raw, str):
        return None
    candidate = raw.strip().lower()
    if candidate in ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]
    # Collapse punctuation/whitespace runs into underscores and retry.
    candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
    if candidate in ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]
    # (fragments, canonical value): ALL fragments must appear in candidate.
    heuristics = [
        (("expression",), ActionType.QUERY_EXPRESSION.value),
        (("differential",), ActionType.DIFFERENTIAL_EXPRESSION.value),
        (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
        (("coexpression",), ActionType.COEXPRESSION_NETWORK.value),
        (("structure",), ActionType.PROTEIN_STRUCTURE_LOOKUP.value),
        (("binding", "site"), ActionType.BINDING_SITE_ANALYSIS.value),
        (("interaction",), ActionType.PROTEIN_INTERACTION_NETWORK.value),
        (("druggab",), ActionType.DRUGGABILITY_SCREEN.value),
        (("clinical",), ActionType.CLINICAL_TRIAL_LOOKUP.value),
        (("toxic",), ActionType.TOXICITY_PANEL.value),
        (("off",), ActionType.OFF_TARGET_SCREEN.value),
        (("strat",), ActionType.PATIENT_STRATIFICATION.value),
        (("literature",), ActionType.LITERATURE_SEARCH.value),
        (("synthes",), ActionType.EVIDENCE_SYNTHESIS.value),
        (("competitor",), ActionType.COMPETITOR_LANDSCAPE.value),
        (("vitro",), ActionType.IN_VITRO_ASSAY.value),
        (("vivo",), ActionType.IN_VIVO_MODEL.value),
        (("crispr",), ActionType.CRISPR_KNOCKOUT.value),
        (("biomarker",), ActionType.BIOMARKER_CORRELATION.value),
        (("red", "flag"), ActionType.FLAG_RED_FLAG.value),
        (("review",), ActionType.REQUEST_EXPERT_REVIEW.value),
        (("submit",), ActionType.SUBMIT_VALIDATION_REPORT.value),
        (("report",), ActionType.SUBMIT_VALIDATION_REPORT.value),
    ]
    for fragments, normalized in heuristics:
        if all(fragment in candidate for fragment in fragments):
            return normalized
    return None
def parse_action(text: str) -> Optional[DrugTargetAction]:
    """Parse a raw model response into a DrugTargetAction.

    Returns None when no JSON object can be extracted or the action type
    cannot be recognized. Confidence is clamped to [0, 1]; the final
    decision is normalized to "go"/"no_go" or dropped.
    """
    payload = extract_json_object(text)
    if payload is None:
        return None
    normalized_type = normalize_action_type(get_payload_value(payload, "action_type"))
    if normalized_type is None:
        return None

    raw_params = get_payload_value(payload, "parameters", "params")
    parameters = raw_params if isinstance(raw_params, dict) else {}

    confidence: Optional[float] = None
    raw_conf = get_payload_value(payload, "confidence")
    if raw_conf is not None:
        try:
            confidence = max(0.0, min(1.0, float(raw_conf)))
        except (TypeError, ValueError):
            confidence = None

    reasoning = get_payload_value(
        payload, "reasoning", "rationale", "justification", "reason"
    )
    if reasoning is not None and not isinstance(reasoning, str):
        # Non-string reasoning (dict/list) gets a compact one-line preview.
        reasoning = compact_preview(reasoning, 200)

    decision = normalize_optional_string(
        get_payload_value(payload, "final_decision", "decision", "go_no_go")
    )
    if decision is not None:
        decision = decision.lower().replace("-", "_")
        if decision not in {"go", "no_go"}:
            decision = None

    return DrugTargetAction(
        action_type=ActionType(normalized_type),
        parameters=parameters,
        reasoning=reasoning or "",
        final_decision=decision,
        confidence=confidence,
    )
def ensure_terminal_payload(action: DrugTargetAction) -> DrugTargetAction:
    """Make sure a SUBMIT_VALIDATION_REPORT carries a decision + confidence.

    Non-terminal actions pass through untouched. For the terminal report, a
    missing decision defaults to "no_go", a missing confidence to 0.5, and
    confidence is clamped into [0, 1].
    """
    if action.action_type is not ActionType.SUBMIT_VALIDATION_REPORT:
        return action
    confidence = 0.5 if action.confidence is None else action.confidence
    clamped = float(min(1.0, max(0.0, confidence)))
    return action.model_copy(
        update={
            "final_decision": action.final_decision or "no_go",
            "confidence": clamped,
        }
    )
| # ── Dashboard state writer ──────────────────────────────────────────── | |
def write_dashboard_state(
    env: DrugTargetEnvironment,
    obs: ValidationObservation,
    *,
    step: int,
    cumulative_reward: float,
    model_response: str = "",
    model_thinking: str = "",
    action: Optional[DrugTargetAction] = None,
    gen_time: float = 0.0,
    episode_done: bool = False,
) -> None:
    """Write a JSON snapshot of the current run to DASHBOARD_STATE_PATH.

    The snapshot combines the agent-visible observation, the chosen action,
    the latest tool output, the dossier, and the environment's hidden latent
    state (ground truth for the dashboard display). The write is best-effort:
    any failure is swallowed so the agent loop never dies on dashboard I/O.
    """
    # Private-attribute access: the dashboard shows ground truth on purpose.
    latent = env._latent
    snapshot: Dict[str, Any] = {
        "timestamp": time.time(),
        "step": step,
        "episode_done": episode_done,
        "cumulative_reward": cumulative_reward,
        "gen_time_s": round(gen_time, 2),
        # NOTE(review): 600/800 truncation is hard-coded while the module
        # defines MODEL_RESPONSE_PREVIEW_CHARS (default 240) that is never
        # used — presumably these were meant to agree; confirm.
        "model_response_raw": model_response[:600],
        "model_thinking": model_thinking[:800],
        "thinking_enabled": ENABLE_THINKING,
    }
    snapshot["task"] = {
        "target_gene": obs.target_gene,
        "indication": obs.indication,
        "disease_context": obs.disease_context,
        "credits_total": obs.credits_total,
    }
    snapshot["resources"] = {
        "credits_used": obs.credits_total - obs.credits_remaining,
        "credits_remaining": obs.credits_remaining,
        "credits_total": obs.credits_total,
    }
    snapshot["pipeline_history"] = [
        {
            "step_index": h.get("step_index"),
            "action_type": h.get("action_type"),
            "output_summary": str(h.get("output_summary", ""))[:120],
            "success": h.get("success"),
            # assumes quality_score is numeric when present — TODO confirm
            "quality_score": round(h.get("quality_score", 0.0), 3),
            "credit_cost": h.get("credit_cost", 0),
        }
        for h in obs.pipeline_history
    ]
    if action:
        snapshot["current_action"] = {
            "action_type": action.action_type.value,
            "parameters": action.parameters,
            "reasoning": action.reasoning,
            "final_decision": action.final_decision,
            "confidence": action.confidence,
        }
    if obs.latest_output:
        lo = obs.latest_output
        snapshot["latest_output"] = {
            "summary": lo.summary,
            "success": lo.success,
            "quality_score": round(lo.quality_score, 3),
            "uncertainty": round(lo.uncertainty, 3),
            "warnings": lo.warnings,
            "data_preview": compact_preview(lo.data, 300) if lo.data else None,
        }
    snapshot["dossier"] = obs.dossier.model_dump()
    snapshot["rule_violations"] = obs.rule_violations
    snapshot["reward_breakdown"] = {
        k: round(v, 4) for k, v in obs.step_reward_breakdown.items()
    }
    if latent:
        # Hidden ground-truth state (target profile, data quality, credits,
        # progress) so the dashboard can compare agent behavior against it.
        t = latent.target
        snapshot["latent"] = {
            "target_profile": {
                "expression_level": t.expression_level,
                "tissue_specificity": round(t.tissue_specificity, 3),
                "disease_overexpression": round(t.disease_overexpression, 3),
                "druggability_score": round(t.druggability_score, 3),
                "binding_pocket_quality": t.binding_pocket_quality,
                "has_known_ligands": t.has_known_ligands,
                "allosteric_site_available": t.allosteric_site_available,
                "selectivity_ratio": round(t.selectivity_ratio, 3),
                "off_target_count": t.off_target_count,
                "off_target_genes": t.off_target_genes,
                "toxicity_profile": t.toxicity_profile,
                "toxicity_tissues": t.toxicity_tissues,
                "clinical_precedent": t.clinical_precedent,
                "clinical_stage_reached": t.clinical_stage_reached,
                "competitor_programs": t.competitor_programs,
                "true_viability_score": round(t.true_viability_score, 3),
                "correct_decision": t.correct_decision,
                "key_evidence_dimensions": t.key_evidence_dimensions,
                "misleading_signals": t.misleading_signals,
            },
            "data_quality": latent.data_quality.model_dump(),
            "credits": {
                "credits_used": latent.credits.credits_used,
                "credits_total": latent.credits.credits_total,
                "credits_remaining": latent.credits.credits_remaining,
            },
            "progress": latent.progress.model_dump(),
            "action_call_counts": latent.action_call_counts,
        }
    try:
        DASHBOARD_STATE_PATH.write_text(
            json.dumps(snapshot, indent=2, default=str), encoding="utf-8",
        )
    except Exception:
        # Deliberate best-effort write; the dashboard is purely advisory.
        pass
def log(msg: str) -> None:
    """Print *msg* to stdout immediately (flushed, for live dashboards/logs)."""
    print(msg, flush=True)
def build_observation_prompt(obs: ValidationObservation) -> str:
    """Build the user-turn prompt for *obs* (thin alias of format_observation)."""
    return format_observation(obs)
def run_with_pipeline(pipe, prompt: str) -> str:
    """Generate text via a transformers pipeline; return '' on any failure.

    The caller treats an empty string as "no usable response", so every
    exception is swallowed rather than propagated.
    """
    try:
        # Reasoning traces need a much larger token budget.
        token_budget = 2048 if ENABLE_THINKING else 300
        raw = pipe(prompt, max_new_tokens=token_budget, return_full_text=False)
    except Exception:
        return ""
    candidate = raw[0] if isinstance(raw, list) and raw else raw
    if isinstance(candidate, str):
        text = candidate
    elif isinstance(candidate, dict):
        # Different pipeline versions use different output field names.
        text = (
            candidate.get("generated_text")
            or candidate.get("text")
            or candidate.get("answer")
        )
    else:
        text = ""
    return text.strip() if isinstance(text, str) else ""
def resolve_torch_runtime() -> Dict[str, Any]:
    """Probe the torch runtime and pick device / dtype settings.

    Prefers bfloat16 on GPUs that report support, float16 on other GPUs,
    and float32 on CPU. ``device_map='auto'`` only when CUDA is present.
    """
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # Older torch builds may lack is_bf16_supported; default to False.
        bf16_probe = getattr(torch.cuda, "is_bf16_supported", lambda: False)
        dtype = torch.bfloat16 if bool(bf16_probe()) else torch.float16
    else:
        dtype = torch.float32
    return {
        "use_cuda": cuda_available,
        "device": "cuda:0" if cuda_available else "cpu",
        "dtype": dtype,
        "device_map": "auto" if cuda_available else None,
        "device_name": torch.cuda.get_device_name(0) if cuda_available else "cpu",
    }
def main():
    """Entry point: load the model once, then serve episodes until quit.

    Flow: resolve the torch runtime -> load either a transformers pipeline
    (USE_PIPELINE) or tokenizer + model -> run one episode immediately ->
    loop waiting for dashboard commands (restart / new scenario / quit).
    """
    tokenizer = None
    model = None
    eos_ids: List[int] = []
    active_pipeline = None
    runtime = resolve_torch_runtime()
    log(
        f"Using local model runtime: device={runtime['device']} "
        f"name={runtime['device_name']} dtype={runtime['dtype']}"
    )
    if USE_PIPELINE:
        log(f"Loading pipeline ({PIPELINE_TASK}) for {MODEL_ID} ...")
        try:
            active_pipeline = pipeline(
                PIPELINE_TASK,
                model=MODEL_ID,
                trust_remote_code=True,
                dtype=runtime["dtype"],
                device=0 if runtime["use_cuda"] else -1,
            )
            log("Pipeline loaded.")
        except Exception as exc:
            # Pipeline load is best-effort; fall through to the manual path.
            log(f"Pipeline load failed ({exc}), falling back to tokenizer+model.")
    if active_pipeline is None:
        log(f"Loading tokenizer for {MODEL_ID} ...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, trust_remote_code=True,
        )
        log("Tokenizer loaded. Loading model (this may download files on first run) ...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            dtype=runtime["dtype"],
            device_map=runtime["device_map"],
            trust_remote_code=True,
        )
        log(f"Model loaded. Device: {model.device}")
        # Collect every plausible end-of-turn token id so generation stops
        # cleanly (Qwen chat models use <|im_end|> besides the plain EOS).
        if tokenizer.eos_token_id is not None:
            eos_ids.append(tokenizer.eos_token_id)
        extra = tokenizer.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
        for tid in extra:
            if isinstance(tid, int) and tid not in eos_ids:
                eos_ids.append(tid)
        log(f"EOS token ids: {eos_ids}")

    def check_dashboard_command() -> Optional[Dict[str, Any]]:
        # Read-and-consume the dashboard command file; None when absent/bad.
        try:
            raw = DASHBOARD_CMD_PATH.read_text(encoding="utf-8")
            try:
                DASHBOARD_CMD_PATH.unlink(missing_ok=True)
            except OSError:
                pass
            return json.loads(raw)
        except (FileNotFoundError, json.JSONDecodeError):
            return None

    def run_episode(scenario_name: Optional[str] = None):
        # One full episode: observe -> prompt -> generate -> parse -> step,
        # mirroring every step into the dashboard JSON file.
        env = DrugTargetEnvironment(scenario_name=scenario_name)
        obs = env.reset()
        log("\n" + "=" * 70)
        log(
            f"TARGET: {obs.target_gene} | INDICATION: {obs.indication} | "
            f"Credits: {obs.credits_total}"
        )
        if ENABLE_THINKING:
            log("Reasoning mode: ENABLED")
        log("=" * 70)
        cumulative_reward = 0.0
        write_dashboard_state(env, obs, step=0, cumulative_reward=0.0)
        for step in range(MAX_EPISODE_STEPS):
            # A dashboard restart aborts the episode between steps.
            cmd = check_dashboard_command()
            if cmd and cmd.get("action") == "restart":
                log("\n[DASHBOARD] Restart requested — ending episode early.")
                break
            user_msg = build_observation_prompt(obs)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ]
            if active_pipeline is not None:
                # Pipelines take a flat prompt, not a chat message list.
                prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}"
            else:
                try:
                    prompt = tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                        enable_thinking=ENABLE_THINKING,
                    )
                except TypeError:
                    # Some chat templates do not accept enable_thinking.
                    prompt = tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
            t0 = time.time()
            if active_pipeline is not None:
                response = run_with_pipeline(active_pipeline, prompt)
                if not response:
                    # NOTE(review): substituting the observation text makes the
                    # "response" unparseable, so the step is then skipped —
                    # presumably a placeholder fallback; confirm intent.
                    response = format_observation(obs)
            else:
                assert tokenizer is not None and model is not None
                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                n_input = inputs["input_ids"].shape[1]
                # Reasoning traces need a much larger token budget.
                max_new = 2048 if ENABLE_THINKING else 300
                with torch.no_grad():
                    output_ids = model.generate(
                        **inputs,
                        max_new_tokens=max_new,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.8,
                        top_k=20,
                        repetition_penalty=1.3,
                        eos_token_id=eos_ids if eos_ids else None,
                    )
                # Decode only the newly generated continuation.
                new_tokens = output_ids[0][n_input:]
                response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            gen_time = time.time() - t0
            thinking = ""
            if ENABLE_THINKING:
                # Split a <think>...</think> block off the visible answer;
                # also tolerate a block that opened but was truncated oddly.
                think_match = re.search(
                    r"<think>(.*?)</think>", response, re.DOTALL
                )
                if think_match:
                    thinking = think_match.group(1).strip()
                    response = response[think_match.end():].strip()
                elif response.startswith("<think>"):
                    parts = response.split("</think>", 1)
                    if len(parts) == 2:
                        thinking = parts[0].replace("<think>", "").strip()
                        response = parts[1].strip()
            is_last_step = (step == MAX_EPISODE_STEPS - 1)
            action = parse_action(response)
            if action is None:
                if is_last_step:
                    # The episode must end with a report; synthesize one.
                    log(
                        "\n [!] Parse failed on final step — forcing "
                        "submit_validation_report."
                    )
                    action = DrugTargetAction(
                        action_type=ActionType.SUBMIT_VALIDATION_REPORT,
                        reasoning="forced terminal report",
                        final_decision="no_go",
                        confidence=0.5,
                    )
                else:
                    log(f"\n [!] Parse failed, skipping step. Raw: {response[:150]}")
                    continue
            if is_last_step and action.action_type != ActionType.SUBMIT_VALIDATION_REPORT:
                # Force a terminal report on the last allowed step.
                log(
                    f"\n [!] Final step — overriding {action.action_type.value} "
                    f"with submit_validation_report."
                )
                action = DrugTargetAction(
                    action_type=ActionType.SUBMIT_VALIDATION_REPORT,
                    reasoning="forced terminal report",
                    final_decision="no_go",
                    confidence=action.confidence or 0.5,
                )
            action = ensure_terminal_payload(action)
            log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
            if thinking:
                log(f" Thinking: {thinking[:200]}")
            if action.reasoning:
                log(f" Reasoning: {action.reasoning}")
            else:
                log(" Reasoning: [model did not provide one]")
            if action.parameters:
                log(f" Parameters: {compact_preview(action.parameters, 200)}")
            obs = env.step(action)
            if obs.latest_output:
                lo = obs.latest_output
                status = "OK" if lo.success else "FAIL"
                log(f" [{status}] {lo.summary}")
                if lo.warnings:
                    log(f" Warnings: {lo.warnings}")
            step_reward = obs.reward
            cumulative_reward += step_reward
            log(f" Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
            log(
                f" Credits remaining: {obs.credits_remaining}"
                f"/{obs.credits_total}"
            )
            write_dashboard_state(
                env, obs,
                step=step + 1,
                cumulative_reward=cumulative_reward,
                model_response=response,
                model_thinking=thinking,
                action=action,
                gen_time=gen_time,
                episode_done=obs.done,
            )
            if obs.rule_violations:
                log(f" Violations: {obs.rule_violations}")
            if obs.done:
                break
        # Episode summary banner.
        log(f"\n{'=' * 70}")
        log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
        log(f" Steps: {obs.step_index}")
        log(f" Total reward: {cumulative_reward:+.3f}")
        log(
            f" Credits used: {obs.credits_total - obs.credits_remaining}"
            f"/{obs.credits_total}"
        )
        log("=" * 70)

    # Drop any stale command left over from a previous run before starting.
    try:
        DASHBOARD_CMD_PATH.unlink(missing_ok=True)
    except OSError:
        pass
    run_episode()
    # Serve dashboard-triggered episodes until a quit command arrives.
    while True:
        log("\nWaiting for dashboard command (restart / new task) ...")
        while True:
            cmd = check_dashboard_command()
            if cmd:
                break
            time.sleep(1.0)
        action_type = cmd.get("action", "restart")
        if action_type == "quit":
            log("Quit requested.")
            break
        scenario = cmd.get("scenario_name")
        log(f"\n[DASHBOARD] {action_type} — scenario={scenario}")
        run_episode(scenario_name=scenario)


if __name__ == "__main__":
    main()