import json
import os
import sys
from typing import Optional

from openai import OpenAI

try:
    from .env import Action
except ImportError:
    from env import Action

_cached_client: Optional[OpenAI] = None
_cached_model = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Set once the "API_BASE_URL missing" warning has been printed, so that an
# agent called repeatedly in heuristic mode does not spam stderr every step.
_warned_missing_base_url = False


def _maybe_get_client() -> Optional[OpenAI]:
    """Return the shared OpenAI-compatible client, or None when unconfigured.

    The endpoint is read from ``API_BASE_URL`` on every call (until a client
    has been created); the key comes from ``HF_TOKEN`` or ``API_KEY``. The
    client is built lazily and cached in the module-level ``_cached_client``.

    Returns:
        The cached client, or None when ``API_BASE_URL`` is empty/unset. In that
        case a warning is printed to stderr once per process.
    """
    global _cached_client, _warned_missing_base_url
    if _cached_client is not None:
        return _cached_client

    base_url = os.environ.get("API_BASE_URL", "").strip()
    # A placeholder key lets the client be constructed without any key set;
    # an authenticated endpoint will presumably reject it at request time.
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") or "hf-missing-token"
    if not base_url:
        if not _warned_missing_base_url:
            print(
                "[WARN] API_BASE_URL is not configured; AnalystAgent will use heuristic mode.",
                file=sys.stderr,
            )
            _warned_missing_base_url = True
        return None

    _cached_client = OpenAI(base_url=base_url, api_key=api_key)
    return _cached_client


def _extract_json_object(raw: str) -> dict:
    """Decode the first JSON object contained in a model reply.

    Chat models often wrap a JSON answer in Markdown code fences or add a
    short preamble/epilogue even when asked for JSON only. Decoding therefore
    starts at the first ``{`` and ignores any text after the object ends.

    Args:
        raw: The model's reply text.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: if ``raw`` contains no ``{`` or the object is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start = raw.find("{")
    if start == -1:
        raise ValueError(f"no JSON object in model output: {raw[:80]!r}")
    payload, _end = json.JSONDecoder().raw_decode(raw[start:])
    return payload


class AnalystAgent:
    """
    Lightweight pharmacovigilance agent for demos and smoke testing.

    The agent can call an OpenAI-compatible chat endpoint when configured,
    but it also has a deterministic fallback policy for offline or local use.
    """

    def __init__(self) -> None:
        # Low-reward decisions recorded by learn(); the three most recent are
        # replayed into the prompt as "mistakes to avoid".
        self.review_memory: list[dict] = []

    def _case_snapshot(self, observation) -> str:
        """Render the observation (reports, knowledge base, recent mistakes) as prompt text.

        Reads ``observation.task_id``, ``observation.reports`` (each with
        ``report_id``, ``suspect_drug``, ``reaction``, ``onset_days``,
        ``severity``, ``outcome`` and ``similar_reports_last_30d``) and
        ``observation.drug_interaction_db``, which must be JSON-serializable.
        """
        report_lines = []
        for report in observation.reports:
            report_lines.append(
                f"- {report.report_id}: suspect={report.suspect_drug}, "
                f"reaction={report.reaction}, onset_days={report.onset_days}, "
                f"severity={report.severity}, outcome={report.outcome}, "
                f"similar_30d={report.similar_reports_last_30d}"
            )

        memory_block = ""
        if self.review_memory:
            memory_block = "\nRecent mistakes to avoid:\n"
            # Only the three most recent entries are shown to keep the prompt short.
            for item in self.review_memory[-3:]:
                memory_block += (
                    f"- On {item['task_id']} you underperformed after choosing "
                    f"{item['classification']} / {item['recommended_action']}. "
                    f"Reason note: {item['note']}\n"
                )

        return (
            f"Task id: {observation.task_id}\n"
            f"Reports:\n"
            + "\n".join(report_lines)
            + "\n"
            f"Knowledge base:\n{json.dumps(observation.drug_interaction_db, ensure_ascii=True, indent=2)}"
            f"{memory_block}"
        )

    def _build_prompt(self, observation) -> str:
        """Build the single user message sent to the chat endpoint."""
        return f"""You are a pharmacovigilance case assessor. Review the case below and return one JSON object only.

Return fields:
- classification: one of new_signal, known_side_effect, noise, duplicate
- suspect_drug: likely causal drug or interaction
- severity_assessment: one of mild, moderate, severe, critical
- recommended_action: one of escalate, log_and_monitor, dismiss, request_more_info
- reasoning: concise mechanistic explanation
- confidence: integer from 0 to 100 indicating how confident you are

Decision principles:
- Repeated known labeled reactions should usually be known_side_effect
- Small but coherent post-marketing clusters on a newer drug can justify new_signal
- If the reporter blames the wrong medication, prefer the stronger causal interaction
- Missing a serious signal is worse than overcalling a weak case

Case:
{self._case_snapshot(observation)}
"""

    def _llm_decision(self, observation) -> Optional[Action]:
        """Ask the configured LLM for a decision.

        Returns:
            An ``Action`` built from the model's JSON reply. Returns None when no
            client is configured or when anything in the request/parse/validate
            chain fails; the failure is logged to stderr.
        """
        client = _maybe_get_client()
        if client is None:
            return None

        try:
            response = client.chat.completions.create(
                model=_cached_model,
                messages=[{"role": "user", "content": self._build_prompt(observation)}],
                temperature=0.0,
                max_tokens=220,
            )
            raw = (response.choices[0].message.content or "").strip()
            # Tolerate code fences / surrounding prose instead of discarding an
            # otherwise valid answer.
            payload = _extract_json_object(raw)
            return Action(**payload)
        except Exception as exc:
            # Deliberate best-effort boundary: transport errors, malformed output
            # and Action construction failures all degrade to the heuristic policy.
            print(f"[WARN] AnalystAgent LLM path failed: {exc}; falling back to heuristics.", file=sys.stderr)
            return None

    def _heuristic_decision(self, observation) -> Action:
        """Deterministic rule-based triage used when the LLM path is unavailable.

        The three pattern rules below name specific drugs. They appear to be
        tuned to bundled demo tasks (NOTE(review): confirm). Anything that
        matches none of them gets a low-confidence ``request_more_info``
        decision built from the first report.
        """
        reports = observation.reports or []
        report_count = len(reports)
        reaction_blob = " ".join(item.reaction.lower() for item in reports)
        db_blob = json.dumps(observation.drug_interaction_db).lower()

        if "dry cough" in reaction_blob and "ace inhibitor" in db_blob:
            return Action(
                classification="known_side_effect",
                suspect_drug="Lisinopril",
                severity_assessment="mild",
                recommended_action="log_and_monitor",
                reasoning="Persistent dry cough is a classic labeled ACE inhibitor effect.",
                confidence=92,
            )

        if report_count >= 3 and ("brady" in reaction_blob or "syncope" in reaction_blob):
            return Action(
                classification="new_signal",
                suspect_drug="Cardiovexa",
                severity_assessment="severe",
                recommended_action="escalate",
                reasoning="A coherent cluster of bradycardia reports on a recently launched drug warrants escalation.",
                confidence=88,
            )

        if "tacrolimus" in db_blob and "voriconazole" in db_blob:
            return Action(
                classification="new_signal",
                suspect_drug="Tacrolimus+Voriconazole",
                severity_assessment="critical",
                recommended_action="escalate",
                reasoning="This looks like a tacrolimus exposure interaction requiring urgent escalation and level review.",
                confidence=84,
            )

        # Fallback: derive drug and severity from the first report. An empty
        # report list yields "unknown"/"moderate" instead of raising IndexError.
        first = reports[0] if reports else None
        suspect = first.suspect_drug if first is not None else "unknown"
        raw_severity = first.severity if first is not None else None
        fallback_severity = raw_severity if raw_severity in {"mild", "moderate", "severe", "critical"} else "moderate"
        return Action(
            classification="new_signal",
            suspect_drug=suspect,
            severity_assessment=fallback_severity,
            recommended_action="request_more_info",
            reasoning="The case is ambiguous, so additional information is needed before final triage.",
            confidence=45,
        )

    def act(self, observation) -> Action:
        """Return the LLM's decision when available, otherwise the heuristic one."""
        llm_action = self._llm_decision(observation)
        if llm_action is not None:
            return llm_action
        return self._heuristic_decision(observation)

    def learn(self, action: Action, observation) -> None:
        """Record ``action`` in ``review_memory`` when the observed reward is below 0.5.

        A missing or None ``observation.reward`` is treated as 0.0, so such
        outcomes are always recorded. The note is ``observation.feedback``,
        or "weak outcome" when that is absent/empty.
        """
        reward = getattr(observation, "reward", 0.0)
        if reward is None:
            reward = 0.0
        if reward < 0.5:
            self.review_memory.append(
                {
                    "task_id": getattr(observation, "task_id", "unknown"),
                    "classification": action.classification,
                    "recommended_action": action.recommended_action,
                    "note": getattr(observation, "feedback", "") or "weak outcome",
                }
            )