import json
import os
import sys
from typing import Optional
from openai import OpenAI
try:
from .env import Action
except ImportError:
from env import Action
# Memoized OpenAI-compatible client; created lazily on first successful
# configuration lookup and reused for the life of the process.
_cached_client: Optional[OpenAI] = None
# Chat model identifier; overridable through the MODEL_NAME environment variable.
_cached_model: str = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
def _maybe_get_client() -> Optional[OpenAI]:
    """Return a process-wide OpenAI-compatible client, or None if unconfigured.

    The client is built once from the API_BASE_URL / HF_TOKEN / API_KEY
    environment variables and memoized in the module-level ``_cached_client``.
    When no base URL is set, a warning is printed to stderr and None is
    returned so callers can fall back to heuristics.
    """
    global _cached_client
    if _cached_client is None:
        endpoint = os.environ.get("API_BASE_URL", "").strip()
        if not endpoint:
            print(
                "[WARN] API_BASE_URL is not configured; AnalystAgent will use heuristic mode.",
                file=sys.stderr,
            )
            return None
        # Prefer HF_TOKEN, then API_KEY; the placeholder keeps the client
        # constructible even without credentials.
        token = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY") or "hf-missing-token"
        _cached_client = OpenAI(base_url=endpoint, api_key=token)
    return _cached_client
class AnalystAgent:
"""
Lightweight pharmacovigilance agent for demos and smoke testing.
The agent can call an OpenAI-compatible chat endpoint when configured, but
it also has a deterministic fallback policy for offline or local use.
"""
def __init__(self) -> None:
self.review_memory: list[dict] = []
def _case_snapshot(self, observation) -> str:
report_lines = []
for report in observation.reports:
report_lines.append(
f"- {report.report_id}: suspect={report.suspect_drug}, "
f"reaction={report.reaction}, onset_days={report.onset_days}, "
f"severity={report.severity}, outcome={report.outcome}, "
f"similar_30d={report.similar_reports_last_30d}"
)
memory_block = ""
if self.review_memory:
memory_block = "\nRecent mistakes to avoid:\n"
for item in self.review_memory[-3:]:
memory_block += (
f"- On {item['task_id']} you underperformed after choosing "
f"{item['classification']} / {item['recommended_action']}. "
f"Reason note: {item['note']}\n"
)
return (
f"Task id: {observation.task_id}\n"
f"Reports:\n" + "\n".join(report_lines) + "\n"
f"Knowledge base:\n{json.dumps(observation.drug_interaction_db, ensure_ascii=True, indent=2)}"
f"{memory_block}"
)
def _build_prompt(self, observation) -> str:
return f"""You are a pharmacovigilance case assessor.
Review the case below and return one JSON object only.
Return fields:
- classification: one of new_signal, known_side_effect, noise, duplicate
- suspect_drug: likely causal drug or interaction
- severity_assessment: one of mild, moderate, severe, critical
- recommended_action: one of escalate, log_and_monitor, dismiss, request_more_info
- reasoning: concise mechanistic explanation
- confidence: integer from 0 to 100 indicating how confident you are
Decision principles:
- Repeated known labeled reactions should usually be known_side_effect
- Small but coherent post-marketing clusters on a newer drug can justify new_signal
- If the reporter blames the wrong medication, prefer the stronger causal interaction
- Missing a serious signal is worse than overcalling a weak case
Case:
{self._case_snapshot(observation)}
"""
def _llm_decision(self, observation) -> Optional[Action]:
client = _maybe_get_client()
if client is None:
return None
try:
response = client.chat.completions.create(
model=_cached_model,
messages=[{"role": "user", "content": self._build_prompt(observation)}],
temperature=0.0,
max_tokens=220,
)
raw = (response.choices[0].message.content or "").strip()
payload = json.loads(raw)
return Action(**payload)
except Exception as exc:
print(f"[WARN] AnalystAgent LLM path failed: {exc}; falling back to heuristics.", file=sys.stderr)
return None
def _heuristic_decision(self, observation) -> Action:
reports = observation.reports
report_count = len(reports)
report = reports[0]
reaction_blob = " ".join(item.reaction.lower() for item in reports)
db_blob = json.dumps(observation.drug_interaction_db).lower()
if "dry cough" in reaction_blob and "ace inhibitor" in db_blob:
return Action(
classification="known_side_effect",
suspect_drug="Lisinopril",
severity_assessment="mild",
recommended_action="log_and_monitor",
reasoning="Persistent dry cough is a classic labeled ACE inhibitor effect.",
confidence=92,
)
if report_count >= 3 and ("brady" in reaction_blob or "syncope" in reaction_blob):
return Action(
classification="new_signal",
suspect_drug="Cardiovexa",
severity_assessment="severe",
recommended_action="escalate",
reasoning="A coherent cluster of bradycardia reports on a recently launched drug warrants escalation.",
confidence=88,
)
if "tacrolimus" in db_blob and "voriconazole" in db_blob:
return Action(
classification="new_signal",
suspect_drug="Tacrolimus+Voriconazole",
severity_assessment="critical",
recommended_action="escalate",
reasoning="This looks like a tacrolimus exposure interaction requiring urgent escalation and level review.",
confidence=84,
)
fallback_severity = report.severity if report.severity in {"mild", "moderate", "severe", "critical"} else "moderate"
return Action(
classification="new_signal",
suspect_drug=report.suspect_drug,
severity_assessment=fallback_severity,
recommended_action="request_more_info",
reasoning="The case is ambiguous, so additional information is needed before final triage.",
confidence=45,
)
def act(self, observation) -> Action:
llm_action = self._llm_decision(observation)
if llm_action is not None:
return llm_action
return self._heuristic_decision(observation)
def learn(self, action: Action, observation) -> None:
reward = getattr(observation, "reward", 0.0)
if reward is None:
reward = 0.0
if reward < 0.5:
self.review_memory.append(
{
"task_id": getattr(observation, "task_id", "unknown"),
"classification": action.classification,
"recommended_action": action.recommended_action,
"note": getattr(observation, "feedback", "") or "weak outcome",
}
)
|