# drugenv-trainer / run_agent.py
# anugrahteesdollar's picture
# fix: include requirements-train.txt + tests (glob bug)
# ad12dda verified
"""Run the drug-target-validation environment with Qwen as the agent.
This script provides the baseline inference loop used by the dashboard.
It instantiates :class:`DrugTargetEnvironment`, formats each
:class:`ValidationObservation` into a prompt, calls the local Qwen model
to obtain a structured :class:`DrugTargetAction`, and writes the running
state to the dashboard JSON file.
"""
from __future__ import annotations
import json
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from models import (
ActionType,
DrugTargetAction,
ValidationObservation,
build_agent_observation_context,
build_agent_system_prompt,
)
from server.hackathon_environment import DrugTargetEnvironment
# JSON snapshot produced by write_dashboard_state(); stored next to this script.
DASHBOARD_STATE_PATH = Path(__file__).parent / "_dashboard_state.json"
# Command file that main() polls, reads and then deletes; stored next to this script.
DASHBOARD_CMD_PATH = Path(__file__).parent / "_dashboard_cmd.json"
# The transformers pipeline backend is opt-in: it is enabled only when
# RUN_AGENT_USE_PIPELINE is set to something other than "0"/"false"/"off"
# (compared case-insensitively after stripping). The default is "0" (disabled).
USE_PIPELINE = os.getenv("RUN_AGENT_USE_PIPELINE", "0").strip().lower() not in {
    "0", "false", "off",
}
def _parse_thinking_flag() -> bool:
    """Decide whether reasoning ("thinking") mode is enabled.

    Command-line flags win over the environment: ``--no-thinking`` is checked
    first, then ``--thinking``.  Without either flag the
    ``RUN_AGENT_ENABLE_THINKING`` environment variable is consulted; it
    defaults to ``"1"`` and only ``"0"``, ``"false"`` or ``"off"``
    (case-insensitive, surrounding whitespace ignored) turn thinking off.
    """
    import sys

    argv = sys.argv
    # Order matters: the negative flag takes precedence when both are given.
    for flag, enabled in (("--no-thinking", False), ("--thinking", True)):
        if flag in argv:
            return enabled

    env_value = os.getenv("RUN_AGENT_ENABLE_THINKING", "1").strip().lower()
    disabled_values = {"0", "false", "off"}
    return env_value not in disabled_values
# Resolved once at import time from the CLI flags / environment variable.
ENABLE_THINKING = _parse_thinking_flag()
# Model identifier handed to pipeline() / from_pretrained() in main().
MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
# Upper bound on loop iterations per episode (env: RUN_AGENT_MAX_EPISODE_STEPS).
# NOTE: a non-integer value in the environment raises ValueError at import time.
MAX_EPISODE_STEPS = int(os.getenv("RUN_AGENT_MAX_EPISODE_STEPS", "20"))
# Task name passed to transformers.pipeline() when the pipeline backend is used.
PIPELINE_TASK = "text-generation"
# The string value of every ActionType member, in enum definition order.
ACTION_TYPES = [a.value for a in ActionType]
# Alternative spellings that normalize_action_type() accepts, mapped to the
# canonical ActionType value. Keys are compared against the lower-cased
# (and, on a second attempt, underscore-normalised) model output.
ACTION_TYPE_ALIASES: Dict[str, str] = {
    # Expression / omics
    "expression_lookup": ActionType.QUERY_EXPRESSION.value,
    "tissue_expression": ActionType.QUERY_EXPRESSION.value,
    "gtex_query": ActionType.QUERY_EXPRESSION.value,
    "de_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "differential_analysis": ActionType.DIFFERENTIAL_EXPRESSION.value,
    "pathway": ActionType.PATHWAY_ENRICHMENT.value,
    "coexpression": ActionType.COEXPRESSION_NETWORK.value,
    # Protein
    "structure_lookup": ActionType.PROTEIN_STRUCTURE_LOOKUP.value,
    "alphafold_lookup": ActionType.PROTEIN_STRUCTURE_LOOKUP.value,
    "binding_site": ActionType.BINDING_SITE_ANALYSIS.value,
    "ppi": ActionType.PROTEIN_INTERACTION_NETWORK.value,
    "interaction_network": ActionType.PROTEIN_INTERACTION_NETWORK.value,
    "druggability": ActionType.DRUGGABILITY_SCREEN.value,
    # Clinical
    "trial_lookup": ActionType.CLINICAL_TRIAL_LOOKUP.value,
    "clinical_lookup": ActionType.CLINICAL_TRIAL_LOOKUP.value,
    "toxicity": ActionType.TOXICITY_PANEL.value,
    "tox_panel": ActionType.TOXICITY_PANEL.value,
    "off_target": ActionType.OFF_TARGET_SCREEN.value,
    "selectivity": ActionType.OFF_TARGET_SCREEN.value,
    "stratification": ActionType.PATIENT_STRATIFICATION.value,
    "patient_subset": ActionType.PATIENT_STRATIFICATION.value,
    # Literature
    "literature": ActionType.LITERATURE_SEARCH.value,
    "pubmed": ActionType.LITERATURE_SEARCH.value,
    "synthesis": ActionType.EVIDENCE_SYNTHESIS.value,
    "competitor": ActionType.COMPETITOR_LANDSCAPE.value,
    # Experimental
    "in_vitro": ActionType.IN_VITRO_ASSAY.value,
    "cell_assay": ActionType.IN_VITRO_ASSAY.value,
    "in_vivo": ActionType.IN_VIVO_MODEL.value,
    "mouse_model": ActionType.IN_VIVO_MODEL.value,
    "crispr": ActionType.CRISPR_KNOCKOUT.value,
    "biomarker": ActionType.BIOMARKER_CORRELATION.value,
    # Meta
    "red_flag": ActionType.FLAG_RED_FLAG.value,
    "expert_review": ActionType.REQUEST_EXPERT_REVIEW.value,
    "submit_report": ActionType.SUBMIT_VALIDATION_REPORT.value,
    "final_report": ActionType.SUBMIT_VALIDATION_REPORT.value,
    "submit": ActionType.SUBMIT_VALIDATION_REPORT.value,
}
# System-role message used for every model call; built once at import time.
SYSTEM_PROMPT = build_agent_system_prompt()
# Default investigation order. format_observation() lists the entries that
# have not yet succeeded as "Suggested next actions"; it is a hint in the
# prompt, not an enforced pipeline.
STANDARD_PIPELINE_ORDER: List[ActionType] = [
    ActionType.QUERY_EXPRESSION,
    ActionType.DRUGGABILITY_SCREEN,
    ActionType.OFF_TARGET_SCREEN,
    ActionType.TOXICITY_PANEL,
    ActionType.CLINICAL_TRIAL_LOOKUP,
    ActionType.LITERATURE_SEARCH,
    ActionType.SUBMIT_VALIDATION_REPORT,
]
# Character budget for model-response previews, overridable through the
# RUN_AGENT_MODEL_RESPONSE_PREVIEW_CHARS environment variable (default 240).
# NOTE(review): no code visible in this file reads this constant — confirm
# whether another module imports it before removing.
MODEL_RESPONSE_PREVIEW_CHARS = int(
    os.getenv("RUN_AGENT_MODEL_RESPONSE_PREVIEW_CHARS", "240")
)
def compact_preview(value: Any, max_chars: int = 160) -> str:
    """Return a short, single-line text preview of *value*.

    The value is serialised as ASCII JSON with sorted keys; anything JSON
    cannot encode falls back to ``str(value)``.  Runs of whitespace are
    collapsed to one space and the result is cut to at most *max_chars*
    characters, ending in ``"..."`` when truncated.

    Args:
        value: Any object to preview.
        max_chars: Maximum length of the returned string.  Values below 3
            leave no room for the ellipsis, so the text is cut hard instead.

    Returns:
        The preview string, never longer than ``max(0, max_chars)``.
    """
    try:
        text = json.dumps(value, ensure_ascii=True, sort_keys=True)
    except (TypeError, ValueError):
        # TypeError: unserialisable object or unsortable mixed keys.
        # ValueError: circular reference detected by json.dumps.
        text = str(value)
    text = re.sub(r"\s+", " ", text).strip()
    if len(text) <= max_chars:
        return text
    if max_chars < 3:
        # A negative slice bound would keep almost the whole text and then
        # append "...", overshooting the limit; truncate hard instead.
        return text[: max(0, max_chars)]
    return text[: max_chars - 3] + "..."
def format_observation(obs: ValidationObservation) -> str:
    """Render *obs* as the user-turn prompt text for the model.

    The prompt lists the task header, the context produced by
    ``build_agent_observation_context``, the last five history entries, the
    actions that already succeeded, the not-yet-completed entries of
    ``STANDARD_PIPELINE_ORDER``, the dossier findings collected so far, the
    latest tool output, any rule violations and finally the JSON output
    contract the model must follow.  Sections with no content are omitted.
    """
    lines: List[str] = [
        f"TARGET: {obs.target_gene} | INDICATION: {obs.indication}",
        f"DISEASE CONTEXT: {obs.disease_context}",
        f"Step: {obs.step_index} | Credits: "
        f"{obs.credits_remaining}/{obs.credits_total}",
    ]

    tool_context = build_agent_observation_context(obs, max_tools=5)
    if tool_context:
        lines.append(tool_context)

    history = obs.pipeline_history
    if history:
        lines.append("Recent history:")
        for entry in history[-5:]:
            status = "OK" if entry.get("success") else "FAIL"
            summary = str(entry.get("output_summary", ""))[:80]
            lines.append(f" [{status}] {entry.get('action_type')}: {summary}")

        # Only successful steps count as completed.
        done_actions = {
            entry.get("action_type") for entry in history if entry.get("success")
        }
        if done_actions:
            done_text = ", ".join(sorted(map(str, done_actions)))
            lines.append(f"Completed actions (try not to repeat): {done_text}")

        pending = [
            step.value
            for step in STANDARD_PIPELINE_ORDER
            if step.value not in done_actions
        ]
        if pending:
            lines.append(f"Suggested next actions (pick one): {', '.join(pending)}")

    dossier = obs.dossier
    if dossier.flagged_red_flags:
        lines.append(f"Red flags so far: {dossier.flagged_red_flags[:5]}")

    finding_sections = (
        ("Expression findings", dossier.expression_findings),
        ("Protein findings", dossier.protein_findings),
        ("Safety findings", dossier.safety_findings),
        ("Clinical findings", dossier.clinical_findings),
        ("Experimental results", dossier.experimental_results),
    )
    for label, findings in finding_sections:
        if findings:
            lines.append(f"{label}: {compact_preview(findings, 200)}")

    latest = obs.latest_output
    if latest and latest.data:
        lines.append(f"Latest output: {compact_preview(latest.data, 200)}")

    if obs.rule_violations:
        lines.append(f"VIOLATIONS: {obs.rule_violations}")

    # Output contract: the parser downstream expects exactly this JSON shape.
    lines.append(
        'Output ONLY a single JSON object with these exact keys, no '
        'comments, no extra text:\n'
        '{"action_type": "<one of the available actions>", '
        '"parameters": {}, "reasoning": "<why this step now>", '
        '"final_decision": null, "confidence": null}\n'
        'For the final submit_validation_report, set "final_decision" to '
        '"go" or "no_go" and "confidence" to a number in [0, 1].'
    )
    return "\n".join(lines)
# ── JSON parsing helpers (kept compatible with prior small-LLM quirks) ─
def _repair_truncated_json(text: str) -> Optional[str]:
    """Try to complete a JSON object that was cut off mid-generation.

    The repair is purely structural: a dangling key (with or without its
    colon) after the last comma is dropped, an unterminated string is
    closed, and every still-open ``{`` / ``[`` is closed in the reverse of
    the order it was opened.  Brackets that appear inside string values are
    ignored.  If the result still does not parse, trailing commas before a
    closer are removed and parsing is retried once.

    Args:
        text: Raw text expected to start (after stripping) with ``{``.

    Returns:
        The repaired JSON text if it parses to a ``dict``, otherwise ``None``.
    """
    s = text.strip()
    if not s.startswith("{"):
        return None

    # Drop an unterminated key/string that follows the last comma ...
    s = re.sub(r',\s*"[^"\n]*$', '', s)
    # ... or a complete key whose value never arrived.
    s = re.sub(r',\s*"[^"\n]*"\s*:\s*$', '', s)

    closer_for = {"{": "}", "[": "]"}
    open_stack: List[str] = []
    in_string = False
    escape = False
    for ch in s:
        if in_string:
            if escape:
                escape = False
            elif ch == "\\":
                escape = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch in closer_for:
            open_stack.append(ch)
        elif ch == "}" or ch == "]":
            # Only pop on a matching closer; stray closers are left for
            # json.loads to reject.
            if open_stack and closer_for[open_stack[-1]] == ch:
                open_stack.pop()

    if in_string:
        if escape:
            # A lone trailing backslash would escape the quote we append.
            s = s[:-1]
        s += '"'
    # Close innermost containers first so nesting stays valid.
    s += "".join(closer_for[opener] for opener in reversed(open_stack))

    without_trailing_commas = re.sub(r',\s*([}\]])', r'\1', s)
    for candidate in (s, without_trailing_commas):
        try:
            obj = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            return candidate
    return None
# Matches, in priority order at each position: a complete double-quoted
# string on a single line (kept verbatim), a ``//`` line comment, or a
# ``/* ... */`` block comment (both removed).
_JSONISH_STRING_OR_COMMENT_RE = re.compile(
    r'"(?:\\[^\n]|[^"\\\n])*"|//[^\n]*|/\*.*?\*/',
    re.DOTALL,
)


def _strip_js_comments(text: str) -> str:
    """Remove JavaScript-style comments from JSON-like *text*.

    Comment markers that occur inside a complete double-quoted string (for
    example the ``//`` in a URL) are left untouched, so valid string values
    are never truncated.  An unterminated string offers no such protection:
    a ``//`` after it on the same line is still treated as a comment.
    """
    def _keep_strings(match: "re.Match[str]") -> str:
        token = match.group(0)
        return token if token.startswith('"') else ""

    return _JSONISH_STRING_OR_COMMENT_RE.sub(_keep_strings, text)
def _normalize_jsonish_text(text: str) -> str:
    """Rewrite common small-LLM JSON mistakes into valid JSON syntax.

    Steps, applied in order: strip JS-style comments, turn the Python
    literals ``None`` / ``True`` / ``False`` into ``null`` / ``true`` /
    ``false`` when they directly follow ``:`` plus one whitespace character,
    and expand a key written as ``"key:",`` into ``"key": "",``.
    """
    cleaned = _strip_js_comments(text)
    rewrites = (
        (r'(?<=:\s)\bNone\b', 'null'),
        (r'(?<=:\s)\bTrue\b', 'true'),
        (r'(?<=:\s)\bFalse\b', 'false'),
        # Colon swallowed into the key's closing quote, value missing.
        (r'"([^"\n]+?):"\s*,', r'"\1": "",'),
    )
    for pattern, replacement in rewrites:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned
def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Pull the first parseable JSON object out of free-form model output.

    The text is normalised, unwrapped if it is itself a JSON string literal,
    and stripped of a surrounding triple-backtick fence.  Candidates are the
    whole text, every brace-balanced ``{...}`` span, and a repaired version
    of the text from its first ``{``.  They are tried longest first and the
    first one that parses to a ``dict`` is returned; ``None`` if none does.
    """
    body = _normalize_jsonish_text(text).strip()

    # Output that is a quoted JSON string: decode it and normalise again.
    if body.startswith('"') and body.endswith('"'):
        try:
            inner = json.loads(body)
        except json.JSONDecodeError:
            inner = None
        if isinstance(inner, str):
            body = _normalize_jsonish_text(inner).strip()

    fence = "```"
    if body.startswith(fence) and body.endswith(fence):
        fenced_lines = body.splitlines()
        if len(fenced_lines) >= 3:
            # Drop the opening and closing fence lines.
            body = "\n".join(fenced_lines[1:-1]).strip()

    def _matching_brace(open_at: int) -> int:
        """Index of the brace balancing the one at *open_at*, or -1."""
        balance = 0
        for pos in range(open_at, len(body)):
            ch = body[pos]
            if ch == "{":
                balance += 1
            elif ch == "}":
                balance -= 1
                if balance == 0:
                    return pos
        return -1

    candidates: List[str] = [body]
    first_brace = body.find("{")
    brace_at = first_brace
    while brace_at != -1:
        close_at = _matching_brace(brace_at)
        if close_at != -1:
            candidates.append(body[brace_at:close_at + 1])
        brace_at = body.find("{", brace_at + 1)

    if first_brace != -1:
        repaired = _repair_truncated_json(body[first_brace:])
        if repaired is not None:
            candidates.append(repaired)

    # Longest first, so an enclosing object wins over objects nested in it.
    for candidate in sorted(candidates, key=len, reverse=True):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None
def _edit_distance(a: str, b: str) -> int:
    """Return the Levenshtein distance between *a* and *b*.

    Uses the two-row dynamic-programming formulation, iterating over the
    longer string so each row is as short as possible.
    """
    longer, shorter = (a, b) if len(a) >= len(b) else (b, a)
    if not shorter:
        return len(longer)

    previous_row = list(range(len(shorter) + 1))
    for row, ch_long in enumerate(longer, start=1):
        current_row = [row]
        for col, ch_short in enumerate(shorter):
            deletion = previous_row[col + 1] + 1
            insertion = current_row[col] + 1
            substitution = previous_row[col] + (ch_long != ch_short)
            current_row.append(min(deletion, insertion, substitution))
        previous_row = current_row
    return previous_row[-1]
def get_payload_value(payload: Dict[str, Any], *names: str) -> Any:
    """Look up the first of *names* in *payload*, tolerating sloppy keys.

    Three passes, each trying the names in the order given: exact key match,
    case-insensitive match, then a fuzzy match where a payload key is
    accepted if its edit distance to a name is at most
    ``max(2, len(name) // 3)``.  Returns ``None`` when nothing matches.
    """
    missing = object()
    for wanted in names:
        hit = payload.get(wanted, missing)
        if hit is not missing:
            return hit

    by_lower = {str(key).lower(): val for key, val in payload.items()}
    wanted_lower = [name.lower() for name in names]
    for wanted in wanted_lower:
        if wanted in by_lower:
            return by_lower[wanted]

    # Fuzzy pass: payload keys drive the outer loop, so the first key (in
    # payload order) that is close to any name wins.
    for key, val in by_lower.items():
        for name, wanted in zip(names, wanted_lower):
            tolerance = max(2, len(name) // 3)
            if _edit_distance(key, wanted) <= tolerance:
                return val
    return None
def normalize_optional_string(value: Any) -> Optional[str]:
    """Coerce a loosely-typed payload value to a clean string or ``None``.

    Strings are stripped (an empty result becomes ``None``); ``None`` and
    booleans yield ``None``; ints and floats are stringified; anything else
    is rendered with ``compact_preview`` limited to 80 characters.
    """
    if isinstance(value, str):
        trimmed = value.strip()
        return trimmed if trimmed else None
    # bool must be rejected before the numeric branch: bool is an int subclass.
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        return str(value)
    return compact_preview(value, 80)
def normalize_action_type(raw: Any) -> Optional[str]:
    """Map a model-supplied action name to a canonical ``ActionType`` value.

    Resolution order: exact match against ``ACTION_TYPES``, alias lookup,
    the same two checks after collapsing every non-alphanumeric run to
    ``_``, and finally keyword heuristics (all fragments of a rule must
    occur in the candidate).  Non-string input and unrecognised names
    return ``None``.
    """
    if not isinstance(raw, str):
        return None
    candidate = raw.strip().lower()
    if candidate in ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]
    candidate = re.sub(r"[^a-z0-9]+", "_", candidate).strip("_")
    if candidate in ACTION_TYPES:
        return candidate
    if candidate in ACTION_TYPE_ALIASES:
        return ACTION_TYPE_ALIASES[candidate]
    # First matching rule wins, so specific fragments must precede the
    # generic ones they contain: "coexpression" and differential-expression
    # names both contain "expression" and would otherwise be shadowed by it.
    heuristics = [
        (("differential",), ActionType.DIFFERENTIAL_EXPRESSION.value),
        (("coexpression",), ActionType.COEXPRESSION_NETWORK.value),
        (("expression",), ActionType.QUERY_EXPRESSION.value),
        (("pathway",), ActionType.PATHWAY_ENRICHMENT.value),
        (("structure",), ActionType.PROTEIN_STRUCTURE_LOOKUP.value),
        (("binding", "site"), ActionType.BINDING_SITE_ANALYSIS.value),
        (("interaction",), ActionType.PROTEIN_INTERACTION_NETWORK.value),
        (("druggab",), ActionType.DRUGGABILITY_SCREEN.value),
        (("clinical",), ActionType.CLINICAL_TRIAL_LOOKUP.value),
        (("toxic",), ActionType.TOXICITY_PANEL.value),
        (("off",), ActionType.OFF_TARGET_SCREEN.value),
        (("strat",), ActionType.PATIENT_STRATIFICATION.value),
        (("literature",), ActionType.LITERATURE_SEARCH.value),
        (("synthes",), ActionType.EVIDENCE_SYNTHESIS.value),
        (("competitor",), ActionType.COMPETITOR_LANDSCAPE.value),
        (("vitro",), ActionType.IN_VITRO_ASSAY.value),
        (("vivo",), ActionType.IN_VIVO_MODEL.value),
        (("crispr",), ActionType.CRISPR_KNOCKOUT.value),
        (("biomarker",), ActionType.BIOMARKER_CORRELATION.value),
        (("red", "flag"), ActionType.FLAG_RED_FLAG.value),
        (("review",), ActionType.REQUEST_EXPERT_REVIEW.value),
        (("submit",), ActionType.SUBMIT_VALIDATION_REPORT.value),
        (("report",), ActionType.SUBMIT_VALIDATION_REPORT.value),
    ]
    for fragments, normalized in heuristics:
        if all(fragment in candidate for fragment in fragments):
            return normalized
    return None
def parse_action(text: str) -> Optional[DrugTargetAction]:
    """Turn raw model output into a ``DrugTargetAction``.

    Returns ``None`` when no JSON object can be extracted or its action type
    cannot be normalised.  Other fields are sanitised rather than rejected:
    non-dict parameters become ``{}``, confidence is clamped to ``[0, 1]``
    (unparseable values become ``None``), non-string reasoning is rendered
    as a preview, and a final decision other than ``go`` / ``no_go`` is
    dropped.
    """
    payload = extract_json_object(text)
    if payload is None:
        return None

    kind = normalize_action_type(get_payload_value(payload, "action_type"))
    if kind is None:
        return None

    params = get_payload_value(payload, "parameters", "params") or {}
    if not isinstance(params, dict):
        params = {}

    confidence: Optional[float] = None
    raw_conf = get_payload_value(payload, "confidence")
    if raw_conf is not None:
        try:
            confidence = max(0.0, min(1.0, float(raw_conf)))
        except (TypeError, ValueError):
            confidence = None

    rationale = get_payload_value(
        payload, "reasoning", "rationale", "justification", "reason"
    )
    if rationale is not None and not isinstance(rationale, str):
        rationale = compact_preview(rationale, 200)
    rationale = rationale or ""

    decision = normalize_optional_string(
        get_payload_value(payload, "final_decision", "decision", "go_no_go")
    )
    if decision is not None:
        # Accept "no-go" / "NO_GO" style spellings.
        decision = decision.lower().replace("-", "_")
        if decision not in {"go", "no_go"}:
            decision = None

    return DrugTargetAction(
        action_type=ActionType(kind),
        parameters=params,
        reasoning=rationale,
        final_decision=decision,
        confidence=confidence,
    )
def ensure_terminal_payload(action: DrugTargetAction) -> DrugTargetAction:
    """Guarantee that a submit-report action has a decision and confidence.

    Actions of any other type are returned unchanged.  For
    ``SUBMIT_VALIDATION_REPORT`` a copy is returned in which a missing or
    empty decision defaults to ``"no_go"`` and a missing confidence to
    ``0.5``; the confidence is clamped to ``[0, 1]``.
    """
    if action.action_type == ActionType.SUBMIT_VALIDATION_REPORT:
        raw_confidence = 0.5 if action.confidence is None else action.confidence
        patch = {
            "final_decision": action.final_decision if action.final_decision else "no_go",
            "confidence": float(max(0.0, min(1.0, raw_confidence))),
        }
        return action.model_copy(update=patch)
    return action
# ── Dashboard state writer ────────────────────────────────────────────
def write_dashboard_state(
    env: DrugTargetEnvironment,
    obs: ValidationObservation,
    *,
    step: int,
    cumulative_reward: float,
    model_response: str = "",
    model_thinking: str = "",
    action: Optional[DrugTargetAction] = None,
    gen_time: float = 0.0,
    episode_done: bool = False,
) -> None:
    """Serialise the current episode state to ``DASHBOARD_STATE_PATH``.

    The snapshot combines the agent-visible observation (task, credits,
    history, dossier, latest output, violations, reward breakdown) with the
    environment's hidden state read from ``env._latent`` when it is set.
    Writing is best-effort: a failure is logged and never interrupts the
    episode.

    Args:
        env: Environment whose ``_latent`` state is included in the snapshot.
        obs: Observation to report.
        step: Step counter to display (0 before the first action).
        cumulative_reward: Reward accumulated so far in the episode.
        model_response: Raw model text; only the first 600 characters are kept.
        model_thinking: Extracted reasoning text; first 800 characters are kept.
        action: The action just executed, if any.
        gen_time: Generation time in seconds for this step.
        episode_done: Whether the episode has finished.
    """
    latent = env._latent
    snapshot: Dict[str, Any] = {
        "timestamp": time.time(),
        "step": step,
        "episode_done": episode_done,
        "cumulative_reward": cumulative_reward,
        "gen_time_s": round(gen_time, 2),
        "model_response_raw": model_response[:600],
        "model_thinking": model_thinking[:800],
        "thinking_enabled": ENABLE_THINKING,
    }
    snapshot["task"] = {
        "target_gene": obs.target_gene,
        "indication": obs.indication,
        "disease_context": obs.disease_context,
        "credits_total": obs.credits_total,
    }
    snapshot["resources"] = {
        "credits_used": obs.credits_total - obs.credits_remaining,
        "credits_remaining": obs.credits_remaining,
        "credits_total": obs.credits_total,
    }
    snapshot["pipeline_history"] = [
        {
            "step_index": h.get("step_index"),
            "action_type": h.get("action_type"),
            "output_summary": str(h.get("output_summary", ""))[:120],
            "success": h.get("success"),
            # `or 0.0` also covers an entry whose quality_score is present
            # but None, which round() would reject.
            "quality_score": round(h.get("quality_score") or 0.0, 3),
            "credit_cost": h.get("credit_cost", 0),
        }
        for h in obs.pipeline_history
    ]
    if action:
        snapshot["current_action"] = {
            "action_type": action.action_type.value,
            "parameters": action.parameters,
            "reasoning": action.reasoning,
            "final_decision": action.final_decision,
            "confidence": action.confidence,
        }
    if obs.latest_output:
        lo = obs.latest_output
        snapshot["latest_output"] = {
            "summary": lo.summary,
            "success": lo.success,
            "quality_score": round(lo.quality_score, 3),
            "uncertainty": round(lo.uncertainty, 3),
            "warnings": lo.warnings,
            "data_preview": compact_preview(lo.data, 300) if lo.data else None,
        }
    snapshot["dossier"] = obs.dossier.model_dump()
    snapshot["rule_violations"] = obs.rule_violations
    snapshot["reward_breakdown"] = {
        k: round(v, 4) for k, v in obs.step_reward_breakdown.items()
    }
    if latent:
        # Hidden ground truth, exposed for the dashboard only; none of this
        # is placed in the model prompt by this file.
        t = latent.target
        snapshot["latent"] = {
            "target_profile": {
                "expression_level": t.expression_level,
                "tissue_specificity": round(t.tissue_specificity, 3),
                "disease_overexpression": round(t.disease_overexpression, 3),
                "druggability_score": round(t.druggability_score, 3),
                "binding_pocket_quality": t.binding_pocket_quality,
                "has_known_ligands": t.has_known_ligands,
                "allosteric_site_available": t.allosteric_site_available,
                "selectivity_ratio": round(t.selectivity_ratio, 3),
                "off_target_count": t.off_target_count,
                "off_target_genes": t.off_target_genes,
                "toxicity_profile": t.toxicity_profile,
                "toxicity_tissues": t.toxicity_tissues,
                "clinical_precedent": t.clinical_precedent,
                "clinical_stage_reached": t.clinical_stage_reached,
                "competitor_programs": t.competitor_programs,
                "true_viability_score": round(t.true_viability_score, 3),
                "correct_decision": t.correct_decision,
                "key_evidence_dimensions": t.key_evidence_dimensions,
                "misleading_signals": t.misleading_signals,
            },
            "data_quality": latent.data_quality.model_dump(),
            "credits": {
                "credits_used": latent.credits.credits_used,
                "credits_total": latent.credits.credits_total,
                "credits_remaining": latent.credits.credits_remaining,
            },
            "progress": latent.progress.model_dump(),
            "action_call_counts": latent.action_call_counts,
        }
    try:
        DASHBOARD_STATE_PATH.write_text(
            json.dumps(snapshot, indent=2, default=str), encoding="utf-8",
        )
    except Exception as exc:
        # Best-effort by design: the dashboard is optional, so a failed
        # write must not abort the episode -- but it should be visible.
        log(f"[dashboard] could not write {DASHBOARD_STATE_PATH.name}: {exc}")
def log(msg: str) -> None:
    """Print *msg* to standard output and flush the stream immediately."""
    print(msg, flush=True)
def build_observation_prompt(obs: ValidationObservation) -> str:
    """Return the user-turn prompt for *obs*; delegates to ``format_observation``."""
    return format_observation(obs)
def run_with_pipeline(pipe, prompt: str) -> str:
    """Generate a completion for *prompt* through a text-generation pipeline.

    Args:
        pipe: Callable pipeline object; invoked as
            ``pipe(prompt, max_new_tokens=..., return_full_text=False)``.
        prompt: Full prompt text.

    Returns:
        The stripped generated text, or ``""`` when generation fails or the
        result has no recognisable text field.  Failures are logged instead
        of raised so the caller can apply its own fallback.
    """
    try:
        # Thinking mode needs room for the reasoning trace before the JSON.
        _pipe_max = 2048 if ENABLE_THINKING else 300
        result = pipe(prompt, max_new_tokens=_pipe_max, return_full_text=False)
    except Exception as exc:
        log(f"[pipeline] generation failed: {exc!r}")
        return ""
    # Pipelines usually return a list with one dict per generated sequence.
    if isinstance(result, list) and result:
        result = result[0]
    if isinstance(result, dict):
        text = result.get("generated_text") or result.get("text") or result.get("answer")
    elif isinstance(result, str):
        text = result
    else:
        text = ""
    return text.strip() if isinstance(text, str) else ""
def resolve_torch_runtime() -> Dict[str, Any]:
    """Describe the torch device and dtype to load the model with.

    With CUDA available the device is ``cuda:0`` with ``device_map="auto"``
    and the dtype is bfloat16 when ``torch.cuda.is_bf16_supported()`` reports
    support, otherwise float16.  Without CUDA everything runs on the CPU in
    float32 with no device map.
    """
    cuda_ok = torch.cuda.is_available()

    if cuda_ok:
        # Builds that lack the probe are treated as having no bf16 support.
        bf16_probe = getattr(torch.cuda, "is_bf16_supported", lambda: False)
        dtype = torch.bfloat16 if bool(bf16_probe()) else torch.float16
        device = "cuda:0"
        device_map = "auto"
        device_name = torch.cuda.get_device_name(0)
    else:
        dtype = torch.float32
        device = "cpu"
        device_map = None
        device_name = "cpu"

    return {
        "use_cuda": cuda_ok,
        "device": device,
        "dtype": dtype,
        "device_map": device_map,
        "device_name": device_name,
    }
def main():
    """Load the model, run one episode, then serve dashboard commands.

    The model is loaded once, either as a ``transformers`` pipeline (when
    ``USE_PIPELINE`` is set and loading succeeds) or as a tokenizer plus
    causal-LM pair.  After the first episode the function polls
    ``DASHBOARD_CMD_PATH`` once per second and starts a new episode for each
    command until a command with ``"action": "quit"`` arrives.
    """
    tokenizer = None
    model = None
    eos_ids: List[int] = []
    active_pipeline = None
    runtime = resolve_torch_runtime()
    log(
        f"Using local model runtime: device={runtime['device']} "
        f"name={runtime['device_name']} dtype={runtime['dtype']}"
    )
    if USE_PIPELINE:
        log(f"Loading pipeline ({PIPELINE_TASK}) for {MODEL_ID} ...")
        try:
            active_pipeline = pipeline(
                PIPELINE_TASK,
                model=MODEL_ID,
                trust_remote_code=True,
                dtype=runtime["dtype"],
                device=0 if runtime["use_cuda"] else -1,
            )
            log("Pipeline loaded.")
        except Exception as exc:
            log(f"Pipeline load failed ({exc}), falling back to tokenizer+model.")
    if active_pipeline is None:
        log(f"Loading tokenizer for {MODEL_ID} ...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, trust_remote_code=True,
        )
        log("Tokenizer loaded. Loading model (this may download files on first run) ...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            dtype=runtime["dtype"],
            device_map=runtime["device_map"],
            trust_remote_code=True,
        )
        log(f"Model loaded. Device: {model.device}")
        # Stop on the tokenizer's EOS plus the chat end-of-turn markers.
        if tokenizer.eos_token_id is not None:
            eos_ids.append(tokenizer.eos_token_id)
        extra = tokenizer.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
        for tid in extra:
            if isinstance(tid, int) and tid not in eos_ids:
                eos_ids.append(tid)
        log(f"EOS token ids: {eos_ids}")

    def check_dashboard_command() -> Optional[Dict[str, Any]]:
        """Consume the pending dashboard command, if any.

        Returns the command dict, or ``None`` when the file is missing,
        unreadable, not valid JSON, or not a JSON object.  The file is
        deleted as soon as it has been read.
        """
        try:
            raw = DASHBOARD_CMD_PATH.read_text(encoding="utf-8")
        except OSError:
            # A missing file is the normal idle case; other OS errors (for
            # example a transient permission error) are retried next poll.
            return None
        try:
            DASHBOARD_CMD_PATH.unlink(missing_ok=True)
        except OSError:
            pass
        try:
            command = json.loads(raw)
        except json.JSONDecodeError:
            return None
        # Callers use .get(); anything but an object is not a command.
        return command if isinstance(command, dict) else None

    def run_episode(scenario_name: Optional[str] = None):
        """Run one episode, for at most ``MAX_EPISODE_STEPS`` iterations."""
        env = DrugTargetEnvironment(scenario_name=scenario_name)
        obs = env.reset()
        log("\n" + "=" * 70)
        log(
            f"TARGET: {obs.target_gene} | INDICATION: {obs.indication} | "
            f"Credits: {obs.credits_total}"
        )
        if ENABLE_THINKING:
            log("Reasoning mode: ENABLED")
        log("=" * 70)
        cumulative_reward = 0.0
        write_dashboard_state(env, obs, step=0, cumulative_reward=0.0)
        for step in range(MAX_EPISODE_STEPS):
            # NOTE(review): the command is consumed here; after a "restart"
            # the outer loop waits for a further command before starting a
            # new episode -- confirm this matches the dashboard protocol.
            cmd = check_dashboard_command()
            if cmd and cmd.get("action") == "restart":
                log("\n[DASHBOARD] Restart requested — ending episode early.")
                break
            user_msg = build_observation_prompt(obs)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ]
            if active_pipeline is not None:
                prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}"
            else:
                try:
                    prompt = tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                        enable_thinking=ENABLE_THINKING,
                    )
                except TypeError:
                    # Retry without the enable_thinking keyword.
                    prompt = tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
            t0 = time.time()
            if active_pipeline is not None:
                response = run_with_pipeline(active_pipeline, prompt)
                if not response:
                    # NOTE(review): this substitutes the prompt text for the
                    # missing response; kept as in the original -- confirm
                    # the intent.
                    response = format_observation(obs)
            else:
                assert tokenizer is not None and model is not None
                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                n_input = inputs["input_ids"].shape[1]
                max_new = 2048 if ENABLE_THINKING else 300
                with torch.no_grad():
                    output_ids = model.generate(
                        **inputs,
                        max_new_tokens=max_new,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.8,
                        top_k=20,
                        repetition_penalty=1.3,
                        eos_token_id=eos_ids if eos_ids else None,
                    )
                # Keep only the newly generated tokens, not the echoed prompt.
                new_tokens = output_ids[0][n_input:]
                response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            gen_time = time.time() - t0
            thinking = ""
            if ENABLE_THINKING:
                think_match = re.search(
                    r"<think>(.*?)</think>", response, re.DOTALL
                )
                if think_match:
                    thinking = think_match.group(1).strip()
                    response = response[think_match.end():].strip()
                elif "</think>" in response:
                    # Closing tag without an opening one (e.g. the opening
                    # tag was part of the prompt): text before the tag is
                    # the reasoning, text after it is the answer.
                    head, _, tail = response.partition("</think>")
                    thinking = head.strip()
                    response = tail.strip()
            is_last_step = (step == MAX_EPISODE_STEPS - 1)
            action = parse_action(response)
            if action is None:
                if is_last_step:
                    log(
                        "\n [!] Parse failed on final step — forcing "
                        "submit_validation_report."
                    )
                    action = DrugTargetAction(
                        action_type=ActionType.SUBMIT_VALIDATION_REPORT,
                        reasoning="forced terminal report",
                        final_decision="no_go",
                        confidence=0.5,
                    )
                else:
                    log(f"\n [!] Parse failed, skipping step. Raw: {response[:150]}")
                    continue
            if is_last_step and action.action_type != ActionType.SUBMIT_VALIDATION_REPORT:
                log(
                    f"\n [!] Final step — overriding {action.action_type.value} "
                    f"with submit_validation_report."
                )
                action = DrugTargetAction(
                    action_type=ActionType.SUBMIT_VALIDATION_REPORT,
                    reasoning="forced terminal report",
                    final_decision="no_go",
                    # `is not None` keeps a legitimate 0.0 confidence, in
                    # line with ensure_terminal_payload().
                    confidence=(
                        action.confidence if action.confidence is not None else 0.5
                    ),
                )
            action = ensure_terminal_payload(action)
            log(f"\nStep {step + 1}: {action.action_type.value} ({gen_time:.1f}s)")
            if thinking:
                log(f" Thinking: {thinking[:200]}")
            if action.reasoning:
                log(f" Reasoning: {action.reasoning}")
            else:
                log(" Reasoning: [model did not provide one]")
            if action.parameters:
                log(f" Parameters: {compact_preview(action.parameters, 200)}")
            obs = env.step(action)
            if obs.latest_output:
                lo = obs.latest_output
                status = "OK" if lo.success else "FAIL"
                log(f" [{status}] {lo.summary}")
                if lo.warnings:
                    log(f" Warnings: {lo.warnings}")
            step_reward = obs.reward
            cumulative_reward += step_reward
            log(f" Reward: {step_reward:+.3f} (cum: {cumulative_reward:+.3f})")
            log(
                f" Credits remaining: {obs.credits_remaining}"
                f"/{obs.credits_total}"
            )
            write_dashboard_state(
                env, obs,
                step=step + 1,
                cumulative_reward=cumulative_reward,
                model_response=response,
                model_thinking=thinking,
                action=action,
                gen_time=gen_time,
                episode_done=obs.done,
            )
            if obs.rule_violations:
                log(f" Violations: {obs.rule_violations}")
            if obs.done:
                break
        log(f"\n{'=' * 70}")
        log("EPISODE COMPLETE" if obs.done else f"MAX STEPS ({MAX_EPISODE_STEPS})")
        log(f" Steps: {obs.step_index}")
        log(f" Total reward: {cumulative_reward:+.3f}")
        log(
            f" Credits used: {obs.credits_total - obs.credits_remaining}"
            f"/{obs.credits_total}"
        )
        log("=" * 70)
        # Discard any command that arrived while the episode was finishing.
        try:
            DASHBOARD_CMD_PATH.unlink(missing_ok=True)
        except OSError:
            pass

    run_episode()
    while True:
        log("\nWaiting for dashboard command (restart / new task) ...")
        while True:
            cmd = check_dashboard_command()
            if cmd:
                break
            time.sleep(1.0)
        action_type = cmd.get("action", "restart")
        if action_type == "quit":
            log("Quit requested.")
            break
        scenario = cmd.get("scenario_name")
        log(f"\n[DASHBOARD] {action_type} — scenario={scenario}")
        run_episode(scenario_name=scenario)
# Run the agent loop only when executed as a script, not when imported.
if __name__ == "__main__":
    main()