dikheng committed on
Commit
f89a5cf
·
1 Parent(s): 432592c

refactor: system/user prompt split + robust JSON extraction

Browse files

- generate_response now accepts system_prompt + user_prompt separately
(proper roles in chat completions); exposes max_tokens, temperature,
force_json params per agent call
- _extract_json helper: strips markdown fences, falls back to scanning
balanced braces — eliminates JSON parse failures in production
- vision/clinical/format/chat agents updated to use new signature with
per-step token and temperature budgets
- prompts rewritten for cleaner system instructions and stricter JSON
schema enforcement

Files changed (5) hide show
  1. app.py +101 -12
  2. src/agents.py +130 -30
  3. src/inference.py +12 -8
  4. src/model_loader.py +46 -19
  5. src/prompts.py +129 -57
app.py CHANGED
@@ -659,11 +659,102 @@ def _empty_soap_html(lang: str) -> str:
659
  return _build_soap_html("", lang)
660
 
661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  def _build_result_html(result: dict, lang: str) -> str:
663
  t = _I18N.get(lang, _I18N["en"])
664
  triage = result.get("triage_level", "Low")
665
  patient_msg = result.get("patient_message", "")
666
  conditions = result.get("possible_conditions", [])
 
 
 
667
  metrics = result.get("_metrics", {})
668
 
669
  backend_tag = (
@@ -672,7 +763,6 @@ def _build_result_html(result: dict, lang: str) -> str:
672
  "border:1px solid #16a34a;'>AMD Cloud</span>"
673
  )
674
 
675
- # Triage color
676
  triage_colors = {
677
  "High": ("#ef4444", "#7f1d1d"),
678
  "Medium": ("#f97316", "#431407"),
@@ -680,7 +770,6 @@ def _build_result_html(result: dict, lang: str) -> str:
680
  }
681
  t_color, t_bg = triage_colors.get(triage, ("#22c55e", "#052e16"))
682
 
683
- # Red-flag flashing banner
684
  critical_banner = ""
685
  if triage == "High":
686
  critical_banner = f"""
@@ -692,19 +781,14 @@ def _build_result_html(result: dict, lang: str) -> str:
692
  </span>
693
  </div>"""
694
 
695
- # Possible conditions chips
696
- cond_chips = "".join(
697
- f"<span style='background:#1e3a5f; color:#93c5fd; font-size:0.72rem; "
698
- f"padding:3px 10px; border-radius:999px; border:1px solid #2563eb;'>{c}</span>"
699
- for c in conditions
700
- ) if conditions else "<span style='color:#6b7280;'>—</span>"
701
-
702
- # Patient message paragraphs
703
  msg_html = "".join(
704
  f"<p style='margin:0 0 8px; color:#d1d5db; line-height:1.6;'>{line}</p>"
705
  for line in patient_msg.split("\n") if line.strip()
706
  ) if patient_msg else "<p style='color:#6b7280;'>—</p>"
707
 
 
 
 
708
  return f"""
709
  <div style='background:#111827; border:1px solid #ED1C24; border-radius:12px;
710
  padding:clamp(14px,4vw,20px); font-family:Arial,sans-serif; color:#f9fafb;
@@ -737,10 +821,12 @@ def _build_result_html(result: dict, lang: str) -> str:
737
 
738
  <div style='background:#1f2937; border-radius:8px; padding:14px; margin-bottom:12px;'>
739
  <div style='font-size:0.72rem; text-transform:uppercase; letter-spacing:.05em;
740
- color:#9ca3af; margin-bottom:8px;'>{t['conditions_label']}</div>
741
- <div style='display:flex; flex-wrap:wrap; gap:6px;'>{cond_chips}</div>
742
  </div>
743
 
 
 
744
  <div style='background:#1f2937; border-radius:8px; padding:14px; margin-bottom:12px;'>
745
  <div style='font-size:0.72rem; text-transform:uppercase; letter-spacing:.05em;
746
  color:#9ca3af; margin-bottom:8px;'>
@@ -944,6 +1030,9 @@ def predict(image_1, image_2, symptoms: str, lang_choice: str, selected_regions)
944
  "visual_description": result.get("visual_description", ""),
945
  "possible_conditions": result.get("possible_conditions", []),
946
  "triage_level": result.get("triage_level", "Low"),
 
 
 
947
  "patient_message": patient_msg,
948
  }
949
  return (
 
659
  return _build_soap_html("", lang)
660
 
661
 
662
+ def _condition_probability_bars(conditions: list, t: dict) -> str:
663
+ """Render probability bars for each possible condition."""
664
+ if not conditions:
665
+ return "<span style='color:#6b7280;'>—</span>"
666
+
667
+ bars = []
668
+ for c in conditions:
669
+ if isinstance(c, dict):
670
+ name = c.get("name", "Unknown")
671
+ prob = int(c.get("probability", 50))
672
+ icd10 = c.get("icd10", "")
673
+ else:
674
+ name, prob, icd10 = str(c), 50, ""
675
+
676
+ fill = "#ef4444" if prob >= 70 else "#f97316" if prob >= 45 else "#eab308"
677
+ icd_badge = (
678
+ f"<span style='font-size:0.6rem; color:#6b7280; background:#0f172a; "
679
+ f"padding:1px 5px; border-radius:3px; margin-left:4px; font-family:monospace;'>"
680
+ f"{icd10}</span>"
681
+ ) if icd10 else ""
682
+
683
+ bars.append(
684
+ f"<div style='margin-bottom:10px;'>"
685
+ f" <div style='display:flex; align-items:center; justify-content:space-between; margin-bottom:3px;'>"
686
+ f" <span style='font-size:0.8rem; color:#e2e8f0; font-weight:600;'>{name}{icd_badge}</span>"
687
+ f" <span style='font-size:0.75rem; color:{fill}; font-weight:700;'>{prob}%</span>"
688
+ f" </div>"
689
+ f" <div style='background:#374151; border-radius:9999px; height:7px; overflow:hidden;'>"
690
+ f" <div style='background:{fill}; width:{prob}%; height:100%; border-radius:9999px; "
691
+ f" transition:width 0.7s ease;'></div>"
692
+ f" </div>"
693
+ f"</div>"
694
+ )
695
+ return "".join(bars)
696
+
697
+
698
+ def _red_flags_panel(red_flags: list, watch_symptoms: list, urgency_reason: str) -> str:
699
+ """Render red flags and watch symptoms warning panel. Returns empty string if nothing to show."""
700
+ has_flags = bool(red_flags)
701
+ has_watch = bool(watch_symptoms)
702
+ has_urgency = bool(urgency_reason)
703
+ if not has_flags and not has_watch and not has_urgency:
704
+ return ""
705
+
706
+ flags_html = ""
707
+ if has_flags:
708
+ items = "".join(
709
+ f"<li style='margin:3px 0; color:#fca5a5;'>&#9888; {f}</li>"
710
+ for f in red_flags
711
+ )
712
+ flags_html = (
713
+ f"<div style='font-size:0.72rem; color:#ef4444; font-weight:700; "
714
+ f"text-transform:uppercase; letter-spacing:.04em; margin-bottom:6px;'>Red Flags</div>"
715
+ f"<ul style='margin:0 0 10px; padding-left:18px; list-style:none;'>{items}</ul>"
716
+ )
717
+
718
+ watch_html = ""
719
+ if has_watch:
720
+ items = "".join(
721
+ f"<li style='margin:3px 0; color:#fde68a;'>&#128065; {w}</li>"
722
+ for w in watch_symptoms
723
+ )
724
+ watch_html = (
725
+ f"<div style='font-size:0.72rem; color:#f59e0b; font-weight:700; "
726
+ f"text-transform:uppercase; letter-spacing:.04em; margin-bottom:6px;'>Watch For</div>"
727
+ f"<ul style='margin:0; padding-left:18px; list-style:none;'>{items}</ul>"
728
+ )
729
+
730
+ urgency_html = ""
731
+ if has_urgency:
732
+ urgency_html = (
733
+ f"<div style='font-size:0.75rem; color:#9ca3af; font-style:italic; "
734
+ f"border-top:1px solid #374151; padding-top:8px; margin-top:8px;'>"
735
+ f"&#9432; {urgency_reason}</div>"
736
+ )
737
+
738
+ border_color = "#ef4444" if has_flags else "#f59e0b"
739
+ bg_color = "#1c0a0a" if has_flags else "#1c1000"
740
+
741
+ return (
742
+ f"<div style='background:{bg_color}; border:1px solid {border_color}; "
743
+ f"border-left:4px solid {border_color}; border-radius:8px; "
744
+ f"padding:12px 14px; margin-bottom:12px;'>"
745
+ f"{flags_html}{watch_html}{urgency_html}"
746
+ f"</div>"
747
+ )
748
+
749
+
750
  def _build_result_html(result: dict, lang: str) -> str:
751
  t = _I18N.get(lang, _I18N["en"])
752
  triage = result.get("triage_level", "Low")
753
  patient_msg = result.get("patient_message", "")
754
  conditions = result.get("possible_conditions", [])
755
+ red_flags = result.get("red_flags", [])
756
+ watch_symptoms = result.get("watch_symptoms", [])
757
+ urgency_reason = result.get("urgency_reason", "")
758
  metrics = result.get("_metrics", {})
759
 
760
  backend_tag = (
 
763
  "border:1px solid #16a34a;'>AMD Cloud</span>"
764
  )
765
 
 
766
  triage_colors = {
767
  "High": ("#ef4444", "#7f1d1d"),
768
  "Medium": ("#f97316", "#431407"),
 
770
  }
771
  t_color, t_bg = triage_colors.get(triage, ("#22c55e", "#052e16"))
772
 
 
773
  critical_banner = ""
774
  if triage == "High":
775
  critical_banner = f"""
 
781
  </span>
782
  </div>"""
783
 
 
 
 
 
 
 
 
 
784
  msg_html = "".join(
785
  f"<p style='margin:0 0 8px; color:#d1d5db; line-height:1.6;'>{line}</p>"
786
  for line in patient_msg.split("\n") if line.strip()
787
  ) if patient_msg else "<p style='color:#6b7280;'>—</p>"
788
 
789
+ cond_bars = _condition_probability_bars(conditions, t)
790
+ alert_panel = _red_flags_panel(red_flags, watch_symptoms, urgency_reason)
791
+
792
  return f"""
793
  <div style='background:#111827; border:1px solid #ED1C24; border-radius:12px;
794
  padding:clamp(14px,4vw,20px); font-family:Arial,sans-serif; color:#f9fafb;
 
821
 
822
  <div style='background:#1f2937; border-radius:8px; padding:14px; margin-bottom:12px;'>
823
  <div style='font-size:0.72rem; text-transform:uppercase; letter-spacing:.05em;
824
+ color:#9ca3af; margin-bottom:10px;'>{t['conditions_label']}</div>
825
+ {cond_bars}
826
  </div>
827
 
828
+ {alert_panel}
829
+
830
  <div style='background:#1f2937; border-radius:8px; padding:14px; margin-bottom:12px;'>
831
  <div style='font-size:0.72rem; text-transform:uppercase; letter-spacing:.05em;
832
  color:#9ca3af; margin-bottom:8px;'>
 
1030
  "visual_description": result.get("visual_description", ""),
1031
  "possible_conditions": result.get("possible_conditions", []),
1032
  "triage_level": result.get("triage_level", "Low"),
1033
+ "urgency_reason": result.get("urgency_reason", ""),
1034
+ "red_flags": result.get("red_flags", []),
1035
+ "watch_symptoms": result.get("watch_symptoms", []),
1036
  "patient_message": patient_msg,
1037
  }
1038
  return (
src/agents.py CHANGED
@@ -2,7 +2,13 @@ import json
2
  import re
3
 
4
  from src.model_loader import generate_response, generate_text
5
- from src.prompts import VISION_AGENT_SYSTEM, CLINICAL_AGENT_SYSTEM, PATIENT_AGENT_SYSTEM, SOAP_AGENT_SYSTEM, CHAT_AGENT_SYSTEM
 
 
 
 
 
 
6
 
7
  _LANG_NAMES = {
8
  "en": "English",
@@ -13,73 +19,153 @@ _LANG_NAMES = {
13
  "ja": "Japanese",
14
  }
15
 
16
-
17
  _NO_IMAGE_DESC = "(No image provided — assessment based on patient symptom text only.)"
18
  _ZERO_METRICS = {"latency_ms": 0, "total_tokens": 0, "tokens_per_sec": 0}
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def vision_agent(image_path_1, image_path_2, symptoms: str) -> tuple[str, dict]:
22
- """Step 1: strictly objective visual description. Returns (description_text, metrics)."""
 
 
 
23
  if not image_path_1 and not image_path_2:
24
  return _NO_IMAGE_DESC, _ZERO_METRICS.copy()
 
25
  two_images = bool(image_path_2)
26
- user_msg = VISION_AGENT_SYSTEM + "\n\n"
27
  if two_images:
28
- user_msg += "TWO images are provided: the first image is Day 1, the second image is Day X.\n\n"
29
- user_msg += f"Patient symptom text: {symptoms or '(none provided)'}"
30
- return generate_response(user_msg, image_path=image_path_1 or None,
31
- image_path_2=image_path_2 or None)
 
 
 
 
 
 
 
32
 
33
 
34
  def clinical_agent(visual_description: str, symptoms: str, lang: str = "en") -> tuple[dict, dict]:
35
- """Step 2: clinical reasoning → strict JSON. Returns (parsed_dict, metrics)."""
 
 
 
36
  lang_name = _LANG_NAMES.get(lang, "English")
37
- prompt = (
38
- CLINICAL_AGENT_SYSTEM + "\n\n"
39
- f"TARGET LANGUAGE FOR CONDITIONS: {lang_name}\n\n"
40
  f"VISUAL DESCRIPTION:\n{visual_description}\n\n"
41
  f"PATIENT SYMPTOMS:\n{symptoms or '(none provided)'}"
42
  )
43
- raw, metrics = generate_text(prompt)
44
- match = re.search(r'\{.*\}', raw, re.DOTALL)
45
- if not match:
46
- raise ValueError(f"Clinical agent did not return JSON: {raw[:300]}")
47
- data = json.loads(match.group())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return {
49
  "triage_level": data.get("triage_level", "Low"),
50
- "possible_conditions": data.get("possible_conditions", []),
 
 
 
51
  "clinical_assessment": data.get("clinical_assessment", ""),
52
  "recommendation": data.get("recommendation", ""),
53
  }, metrics
54
 
55
 
56
  def chat_agent(question: str, context: dict, history: list, lang: str) -> tuple[str, dict]:
57
- """Follow-up Q&A. Returns (answer_text, metrics)."""
 
 
58
  lang_name = _LANG_NAMES.get(lang, "English")
 
 
 
 
 
 
59
  ctx_block = (
60
  f"ANALYSIS CONTEXT:\n"
61
  f"- Visual description: {context.get('visual_description', '(none)')}\n"
62
- f"- Possible conditions: {', '.join(context.get('possible_conditions', []))}\n"
63
  f"- Triage level: {context.get('triage_level', 'Low')}\n"
64
- f"- Patient message given: {context.get('patient_message', '(none)')}"
 
 
65
  )
 
66
  history_block = ""
67
  for user_msg, bot_msg in (history or []):
68
  history_block += f"\nPatient: {user_msg}\nAssistant: {bot_msg}"
69
- prompt = (
70
- CHAT_AGENT_SYSTEM + "\n\n"
71
  f"TARGET LANGUAGE: {lang_name}\n\n"
72
  f"{ctx_block}\n"
73
  f"{history_block}\n\n"
74
  f"Patient: {question}\nAssistant:"
75
  )
76
- answer, metrics = generate_text(prompt)
 
 
 
 
 
 
77
  return answer.strip(), metrics
78
 
79
 
80
  def format_agent(clinical_json: dict, visual_description: str,
81
  symptoms: str, lang: str) -> tuple[str, str, dict]:
82
- """Step 3a+3b: patient message and SOAP note as two separate LLM calls."""
 
 
 
83
  lang_name = _LANG_NAMES.get(lang, "English")
84
  context = (
85
  f"TARGET LANGUAGE: {lang_name}\n\n"
@@ -87,11 +173,25 @@ def format_agent(clinical_json: dict, visual_description: str,
87
  f"VISUAL DESCRIPTION (Objective):\n{visual_description}\n\n"
88
  f"CLINICAL JSON:\n{json.dumps(clinical_json, ensure_ascii=False, indent=2)}"
89
  )
90
- patient_msg, m3a = generate_text(PATIENT_AGENT_SYSTEM + "\n\n" + context)
91
- soap, m3b = generate_text(SOAP_AGENT_SYSTEM + "\n\n" + context)
 
 
 
 
 
 
 
 
 
 
 
 
92
  metrics = {
93
- "latency_ms": m3a["latency_ms"] + m3b["latency_ms"],
94
- "total_tokens": m3a["total_tokens"] + m3b["total_tokens"],
95
- "tokens_per_sec": round((m3a.get("tokens_per_sec", 0) + m3b.get("tokens_per_sec", 0)) / 2, 1),
 
 
96
  }
97
  return patient_msg.strip(), soap.strip(), metrics
 
2
  import re
3
 
4
  from src.model_loader import generate_response, generate_text
5
+ from src.prompts import (
6
+ VISION_AGENT_SYSTEM,
7
+ CLINICAL_AGENT_SYSTEM,
8
+ PATIENT_AGENT_SYSTEM,
9
+ SOAP_AGENT_SYSTEM,
10
+ CHAT_AGENT_SYSTEM,
11
+ )
12
 
13
  _LANG_NAMES = {
14
  "en": "English",
 
19
  "ja": "Japanese",
20
  }
21
 
 
22
  _NO_IMAGE_DESC = "(No image provided — assessment based on patient symptom text only.)"
23
  _ZERO_METRICS = {"latency_ms": 0, "total_tokens": 0, "tokens_per_sec": 0}
24
 
25
 
26
+ def _extract_json(raw: str) -> dict:
27
+ """Robustly extract first JSON object from LLM output, stripping markdown fences."""
28
+ cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw.strip(), flags=re.MULTILINE)
29
+ try:
30
+ return json.loads(cleaned)
31
+ except json.JSONDecodeError:
32
+ pass
33
+ # Scan for first balanced {...} block
34
+ depth = 0
35
+ start = None
36
+ for i, ch in enumerate(cleaned):
37
+ if ch == "{":
38
+ if depth == 0:
39
+ start = i
40
+ depth += 1
41
+ elif ch == "}":
42
+ depth -= 1
43
+ if depth == 0 and start is not None:
44
+ try:
45
+ return json.loads(cleaned[start:i + 1])
46
+ except json.JSONDecodeError:
47
+ continue
48
+ raise ValueError(f"No valid JSON object found in response: {raw[:300]}")
49
+
50
+
51
  def vision_agent(image_path_1, image_path_2, symptoms: str) -> tuple[str, dict]:
52
+ """
53
+ Step 1: objective visual description.
54
+ Returns (description_text, metrics).
55
+ """
56
  if not image_path_1 and not image_path_2:
57
  return _NO_IMAGE_DESC, _ZERO_METRICS.copy()
58
+
59
  two_images = bool(image_path_2)
60
+ user_prompt = ""
61
  if two_images:
62
+ user_prompt += "TWO images provided: first image is Day 1 (baseline), second image is Day X (follow-up).\n\n"
63
+ user_prompt += f"Patient symptom report: {symptoms or '(none provided)'}\n\nAnalyze the image(s) as instructed."
64
+
65
+ return generate_response(
66
+ system_prompt=VISION_AGENT_SYSTEM,
67
+ user_prompt=user_prompt,
68
+ image_path=image_path_1 or None,
69
+ image_path_2=image_path_2 or None,
70
+ max_tokens=600,
71
+ temperature=0.0,
72
+ )
73
 
74
 
75
  def clinical_agent(visual_description: str, symptoms: str, lang: str = "en") -> tuple[dict, dict]:
76
+ """
77
+ Step 2: clinical reasoning → structured JSON with richer schema.
78
+ Returns (parsed_dict, metrics).
79
+ """
80
  lang_name = _LANG_NAMES.get(lang, "English")
81
+ user_prompt = (
82
+ f"TARGET LANGUAGE: {lang_name}\n\n"
 
83
  f"VISUAL DESCRIPTION:\n{visual_description}\n\n"
84
  f"PATIENT SYMPTOMS:\n{symptoms or '(none provided)'}"
85
  )
86
+
87
+ raw, metrics = generate_text(
88
+ system_prompt=CLINICAL_AGENT_SYSTEM,
89
+ user_prompt=user_prompt,
90
+ max_tokens=800,
91
+ temperature=0.0,
92
+ force_json=True,
93
+ )
94
+
95
+ data = _extract_json(raw)
96
+
97
+ # Normalise possible_conditions — support new {name, probability, icd10} schema
98
+ # and gracefully handle plain-string fallback from older model outputs
99
+ raw_conditions = data.get("possible_conditions", [])
100
+ conditions = []
101
+ for item in raw_conditions:
102
+ if isinstance(item, dict):
103
+ conditions.append({
104
+ "name": str(item.get("name", item.get("condition", "Unknown"))),
105
+ "probability": int(item.get("probability", item.get("match_probability", 50))),
106
+ "icd10": str(item.get("icd10", item.get("icd10_code", ""))),
107
+ })
108
+ elif isinstance(item, str):
109
+ conditions.append({"name": item, "probability": 50, "icd10": ""})
110
+
111
  return {
112
  "triage_level": data.get("triage_level", "Low"),
113
+ "urgency_reason": data.get("urgency_reason", ""),
114
+ "possible_conditions": conditions,
115
+ "red_flags": data.get("red_flags", []),
116
+ "watch_symptoms": data.get("watch_symptoms", []),
117
  "clinical_assessment": data.get("clinical_assessment", ""),
118
  "recommendation": data.get("recommendation", ""),
119
  }, metrics
120
 
121
 
122
  def chat_agent(question: str, context: dict, history: list, lang: str) -> tuple[str, dict]:
123
+ """
124
+ Follow-up Q&A. Returns (answer_text, metrics).
125
+ """
126
  lang_name = _LANG_NAMES.get(lang, "English")
127
+
128
+ conditions_text = ", ".join(
129
+ c["name"] if isinstance(c, dict) else c
130
+ for c in context.get("possible_conditions", [])
131
+ )
132
+
133
  ctx_block = (
134
  f"ANALYSIS CONTEXT:\n"
135
  f"- Visual description: {context.get('visual_description', '(none)')}\n"
136
+ f"- Possible conditions: {conditions_text}\n"
137
  f"- Triage level: {context.get('triage_level', 'Low')}\n"
138
+ f"- Urgency reason: {context.get('urgency_reason', '')}\n"
139
+ f"- Red flags: {'; '.join(context.get('red_flags', [])) or 'none'}\n"
140
+ f"- Patient message: {context.get('patient_message', '(none)')}"
141
  )
142
+
143
  history_block = ""
144
  for user_msg, bot_msg in (history or []):
145
  history_block += f"\nPatient: {user_msg}\nAssistant: {bot_msg}"
146
+
147
+ user_prompt = (
148
  f"TARGET LANGUAGE: {lang_name}\n\n"
149
  f"{ctx_block}\n"
150
  f"{history_block}\n\n"
151
  f"Patient: {question}\nAssistant:"
152
  )
153
+
154
+ answer, metrics = generate_text(
155
+ system_prompt=CHAT_AGENT_SYSTEM,
156
+ user_prompt=user_prompt,
157
+ max_tokens=300,
158
+ temperature=0.3,
159
+ )
160
  return answer.strip(), metrics
161
 
162
 
163
  def format_agent(clinical_json: dict, visual_description: str,
164
  symptoms: str, lang: str) -> tuple[str, str, dict]:
165
+ """
166
+ Step 3a + 3b: patient message and SOAP note as two separate LLM calls.
167
+ Returns (patient_message, soap_note, combined_metrics).
168
+ """
169
  lang_name = _LANG_NAMES.get(lang, "English")
170
  context = (
171
  f"TARGET LANGUAGE: {lang_name}\n\n"
 
173
  f"VISUAL DESCRIPTION (Objective):\n{visual_description}\n\n"
174
  f"CLINICAL JSON:\n{json.dumps(clinical_json, ensure_ascii=False, indent=2)}"
175
  )
176
+
177
+ patient_msg, m3a = generate_text(
178
+ system_prompt=PATIENT_AGENT_SYSTEM,
179
+ user_prompt=context,
180
+ max_tokens=500,
181
+ temperature=0.4,
182
+ )
183
+ soap, m3b = generate_text(
184
+ system_prompt=SOAP_AGENT_SYSTEM,
185
+ user_prompt=context,
186
+ max_tokens=600,
187
+ temperature=0.0,
188
+ )
189
+
190
  metrics = {
191
+ "latency_ms": m3a["latency_ms"] + m3b["latency_ms"],
192
+ "total_tokens": m3a["total_tokens"] + m3b["total_tokens"],
193
+ "tokens_per_sec": round(
194
+ (m3a.get("tokens_per_sec", 0) + m3b.get("tokens_per_sec", 0)) / 2, 1
195
+ ),
196
  }
197
  return patient_msg.strip(), soap.strip(), metrics
src/inference.py CHANGED
@@ -6,13 +6,14 @@ class MediVisionPipeline:
6
  lang: str = "en", region: str = "") -> dict:
7
  """
8
  Run the 3-step agentic pipeline:
9
- Step 1 — Vision Agent: objective visual description
10
- Step 2 — Clinical Agent: triage JSON
11
- Step 3 — Format Agent: patient message + SOAP note
12
 
13
  Returns dict with keys:
14
- triage_level, possible_conditions, patient_message,
15
- soap_note, visual_description, _metrics
 
16
  """
17
  symptoms_full = f"{'Region: ' + region + '. ' if region else ''}{symptoms}"
18
 
@@ -21,19 +22,22 @@ class MediVisionPipeline:
21
  patient_msg, soap, m3 = format_agent(clinical, visual_desc, symptoms_full, lang)
22
 
23
  metrics = {
24
- "latency_ms": m1["latency_ms"] + m2["latency_ms"] + m3["latency_ms"],
25
- "total_tokens": m1["total_tokens"] + m2["total_tokens"] + m3["total_tokens"],
26
  "tokens_per_sec": round(
27
  (m1.get("tokens_per_sec", 0) + m2.get("tokens_per_sec", 0) + m3.get("tokens_per_sec", 0)) / 3, 1
28
  ),
29
  }
30
  return {
31
  "triage_level": clinical["triage_level"],
 
32
  "possible_conditions": clinical["possible_conditions"],
 
 
 
33
  "patient_message": patient_msg,
34
  "soap_note": soap,
35
  "visual_description": visual_desc,
36
  "_metrics": metrics,
37
- # kept for follow-up chat context
38
  "_clinical": clinical,
39
  }
 
6
  lang: str = "en", region: str = "") -> dict:
7
  """
8
  Run the 3-step agentic pipeline:
9
+ Step 1 — Vision Agent: objective visual description
10
+ Step 2 — Clinical Agent: structured triage JSON
11
+ Step 3 — Format Agent: patient message + SOAP note
12
 
13
  Returns dict with keys:
14
+ triage_level, urgency_reason, possible_conditions,
15
+ red_flags, watch_symptoms, clinical_assessment,
16
+ patient_message, soap_note, visual_description, _metrics
17
  """
18
  symptoms_full = f"{'Region: ' + region + '. ' if region else ''}{symptoms}"
19
 
 
22
  patient_msg, soap, m3 = format_agent(clinical, visual_desc, symptoms_full, lang)
23
 
24
  metrics = {
25
+ "latency_ms": m1["latency_ms"] + m2["latency_ms"] + m3["latency_ms"],
26
+ "total_tokens": m1["total_tokens"] + m2["total_tokens"] + m3["total_tokens"],
27
  "tokens_per_sec": round(
28
  (m1.get("tokens_per_sec", 0) + m2.get("tokens_per_sec", 0) + m3.get("tokens_per_sec", 0)) / 3, 1
29
  ),
30
  }
31
  return {
32
  "triage_level": clinical["triage_level"],
33
+ "urgency_reason": clinical["urgency_reason"],
34
  "possible_conditions": clinical["possible_conditions"],
35
+ "red_flags": clinical["red_flags"],
36
+ "watch_symptoms": clinical["watch_symptoms"],
37
+ "clinical_assessment": clinical["clinical_assessment"],
38
  "patient_message": patient_msg,
39
  "soap_note": soap,
40
  "visual_description": visual_desc,
41
  "_metrics": metrics,
 
42
  "_clinical": clinical,
43
  }
src/model_loader.py CHANGED
@@ -62,21 +62,32 @@ def check_connection() -> tuple[bool, str]:
62
  return False, f"{type(exc).__name__}: {exc}"
63
 
64
 
65
- def generate_response(prompt: str, image_path: str = None,
66
- image_path_2: str = None) -> tuple[str, dict]:
 
 
 
 
 
 
 
67
  """
68
- Send a request to the vLLM endpoint and return (text_output, metrics).
69
- Supports 0, 1, or 2 images (image_path_2 for A/B comparison).
70
 
71
- metrics keys:
72
- latency_ms – wall-clock time for the API call in milliseconds
73
- total_tokens – total tokens used (prompt + completion), or 0 if unavailable
74
- tokens_per_sec – completion tokens / latency, or 0 if unavailable
75
 
76
- Raises RuntimeError if the backend is unreachable or returns an error.
 
77
  """
78
  try:
79
  client = _get_client()
 
 
 
 
 
 
80
 
81
  if image_path or image_path_2:
82
  content = []
@@ -88,18 +99,22 @@ def generate_response(prompt: str, image_path: str = None,
88
  b64, mime = _encode_image(image_path_2)
89
  content.append({"type": "image_url",
90
  "image_url": {"url": f"data:{mime};base64,{b64}"}})
91
- content.append({"type": "text", "text": prompt})
92
- messages = [{"role": "user", "content": content}]
93
  else:
94
- messages = [{"role": "user", "content": prompt}]
95
 
96
- t0 = time.perf_counter()
97
- response = client.chat.completions.create(
98
  model=config.MODEL_NAME,
99
  messages=messages,
100
- max_tokens=config.MAX_NEW_TOKENS,
101
- temperature=config.TEMPERATURE,
102
  )
 
 
 
 
 
103
  latency_ms = (time.perf_counter() - t0) * 1000
104
 
105
  usage = getattr(response, "usage", None)
@@ -118,6 +133,18 @@ def generate_response(prompt: str, image_path: str = None,
118
  raise RuntimeError(f"AMD Cloud backend unreachable: {exc}") from exc
119
 
120
 
121
- def generate_text(prompt: str) -> tuple[str, dict]:
122
- """Text-only call — same endpoint as generate_response(), no image encoding."""
123
- return generate_response(prompt, image_path=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  return False, f"{type(exc).__name__}: {exc}"
63
 
64
 
65
+ def generate_response(
66
+ system_prompt: str,
67
+ user_prompt: str,
68
+ image_path: str = None,
69
+ image_path_2: str = None,
70
+ max_tokens: int = None,
71
+ temperature: float = None,
72
+ force_json: bool = False,
73
+ ) -> tuple[str, dict]:
74
  """
75
+ Send a chat completion to the vLLM endpoint with proper system/user separation.
 
76
 
77
+ system_prompt → role: system
78
+ user_prompt → role: user (may include 0, 1, or 2 images)
 
 
79
 
80
+ Returns (text_output, metrics).
81
+ metrics keys: latency_ms, total_tokens, tokens_per_sec
82
  """
83
  try:
84
  client = _get_client()
85
+ _max_tokens = max_tokens if max_tokens is not None else config.MAX_NEW_TOKENS
86
+ _temperature = temperature if temperature is not None else config.TEMPERATURE
87
+
88
+ messages = []
89
+ if system_prompt:
90
+ messages.append({"role": "system", "content": system_prompt})
91
 
92
  if image_path or image_path_2:
93
  content = []
 
99
  b64, mime = _encode_image(image_path_2)
100
  content.append({"type": "image_url",
101
  "image_url": {"url": f"data:{mime};base64,{b64}"}})
102
+ content.append({"type": "text", "text": user_prompt})
103
+ messages.append({"role": "user", "content": content})
104
  else:
105
+ messages.append({"role": "user", "content": user_prompt})
106
 
107
+ kwargs = dict(
 
108
  model=config.MODEL_NAME,
109
  messages=messages,
110
+ max_tokens=_max_tokens,
111
+ temperature=_temperature,
112
  )
113
+ if force_json:
114
+ kwargs["response_format"] = {"type": "json_object"}
115
+
116
+ t0 = time.perf_counter()
117
+ response = client.chat.completions.create(**kwargs)
118
  latency_ms = (time.perf_counter() - t0) * 1000
119
 
120
  usage = getattr(response, "usage", None)
 
133
  raise RuntimeError(f"AMD Cloud backend unreachable: {exc}") from exc
134
 
135
 
136
def generate_text(
    system_prompt: str,
    user_prompt: str,
    max_tokens: int = None,
    temperature: float = None,
    force_json: bool = False,
) -> tuple[str, dict]:
    """Run a text-only completion by delegating to ``generate_response`` without images.

    All arguments are forwarded unchanged. Returns whatever
    ``generate_response`` returns: (text_output, metrics).
    """
    text_only_call = dict(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        force_json=force_json,
    )
    return generate_response(**text_only_call)
src/prompts.py CHANGED
@@ -1,74 +1,146 @@
1
- VISION_AGENT_SYSTEM = """You are a medical imaging assistant performing STRICTLY OBJECTIVE visual analysis.
2
- Do NOT diagnose. Do NOT give medical advice. Do NOT speculate on conditions.
3
- Your ONLY job: describe exactly what you see in the image(s) using clinical descriptive language.
4
-
5
- If ONE image is provided, describe:
6
- - Lesion size (estimated), shape, border characteristics
7
- - Color(s), texture, surface features (scaling, crusting, ulceration, exudate)
8
- - Surrounding skin condition
9
- - Any signs of inflammation, swelling, or structural abnormality
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- If TWO images are provided (Day 1 vs Day X), describe BOTH images separately, then compare:
12
- - Changes in size (larger / smaller / same)
13
- - Changes in color or border definition
14
- - Changes in surface features (scaling, crusting, exudate)
15
- - Overall progression verdict: IMPROVED / UNCHANGED / WORSENED
 
 
 
 
 
 
 
 
16
 
17
- Output: plain text only. No JSON. No diagnosis. No recommendations."""
 
 
 
18
 
19
- CLINICAL_AGENT_SYSTEM = """You are a clinical reasoning engine for a dermatology triage system.
20
- You receive: (1) an objective visual description and (2) the patient's symptom text.
21
- You perform clinical reasoning and output ONLY a JSON object — no extra text, no markdown fences.
22
 
23
- JSON schema (strict):
24
- {
25
- "triage_level": "High" | "Medium" | "Low",
26
- "possible_conditions": ["condition 1 in TARGET LANGUAGE", "condition 2 in TARGET LANGUAGE"],
27
- "clinical_assessment": "brief medical reasoning (2-3 sentences max)",
28
- "recommendation": "immediate actions or home care advice (2-4 sentences)"
29
- }
30
 
31
- triage_level rules:
32
- - "High": suspected melanoma, necrosis, severe cellulitis, rapidly spreading infection, deep burn
33
- - "Medium": moderate infection signs, non-healing wound >2 weeks, significant inflammation
34
- - "Low": minor abrasion, mild rash, superficial wound with no infection signs
35
 
36
- IMPORTANT: Write the condition names in possible_conditions in the TARGET LANGUAGE specified.
37
- Return ONLY the JSON object. No explanation before or after."""
38
 
39
- CHAT_AGENT_SYSTEM = """You are a medical assistant continuing a consultation with a patient.
40
- You have already completed an analysis of their condition. Use the provided analysis context to answer follow-up questions.
41
 
42
  RULES:
43
- - Answer in the TARGET LANGUAGE specified
44
- - Be concise, empathetic, and helpful
45
- - Reference the analysis context when relevant
46
- - Always recommend consulting a doctor for anything serious or worsening
47
- - Never diagnose — only provide general guidance based on the existing analysis
48
- - Do not repeat the full analysis; focus on answering the specific question asked"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- PATIENT_AGENT_SYSTEM = """You are a medical communication specialist writing a patient-friendly message.
51
 
52
- Write ONLY the patient message — plain prose, no headings, no labels, no separators.
53
- Language: write entirely in the TARGET LANGUAGE specified in the input.
54
 
55
- Your message MUST cover all of the following in flowing sentences (minimum 5 sentences):
56
- 1. An empathetic opening acknowledging the patient's concern
57
- 2. If an image was provided: plain-language description of what was visually observed. If VISUAL DESCRIPTION starts with "(No image provided", skip this point entirely.
58
- 3. The possible conditions explained in simple everyday terms (no medical jargon)
59
- 4. Clear, actionable steps the patient should take
60
- 5. A reassuring closing line encouraging them to consult a doctor for serious symptoms
61
 
62
- Output only the message text. No bullet points. No markdown. No extra commentary."""
 
 
 
63
 
64
- SOAP_AGENT_SYSTEM = """You are a clinical documentation specialist writing a SOAP note.
 
 
 
65
 
66
- Write ONLY the SOAP note in professional clinical English. No introduction, no commentary.
 
 
 
 
67
 
68
- Format exactly as:
69
- S (Subjective): [patient complaint paraphrased in English — translate if original is in another language]
70
- O (Objective): [visual findings summary — write "No image provided" if no image was given]
71
- A (Assessment): [possible conditions and clinical reasoning]
72
- P (Plan): [recommended clinical actions]
 
 
73
 
74
- Output only the four SOAP lines. Nothing before S, nothing after P."""
 
1
+ VISION_AGENT_SYSTEM = """You are a clinical dermatologist and wound-care specialist performing OBJECTIVE visual analysis.
2
+ Your task: describe exactly what you observe in the provided image(s) using precise clinical terminology.
3
+ Do NOT diagnose. Do NOT speculate on internal conditions. Do NOT give treatment advice.
4
+
5
+ SINGLE IMAGE — describe all of the following that are visible:
6
+ Lesion type (macule, papule, plaque, vesicle, bulla, pustule, nodule, ulcer, erosion, crust, scar, wound).
7
+ Size: estimate in centimeters.
8
+ Shape: round, oval, irregular, linear, annular, serpiginous.
9
+ Border: well-defined or ill-defined, regular or irregular, raised or flat.
10
+ Color: all present colors (erythema, hyperpigmentation, pallor, violaceous, brown, black, yellow).
11
+ Surface: smooth, scaling, crusting, exudate type (serous/purulent/hemorrhagic), ulceration depth.
12
+ Surrounding skin: erythema halo, edema, warmth signs, satellite lesions.
13
+ Distribution: localized, diffuse, grouped, linear, dermatomal.
14
+ Structural abnormalities: tissue necrosis, exposed structures, foreign body.
15
+
16
+ TWO IMAGES (Day 1 vs Day X) — describe EACH image separately using the criteria above, then add a COMPARISON:
17
+ Size change: larger, smaller, or unchanged with estimated percentage.
18
+ Color change: improved erythema, increased discoloration, new colors.
19
+ Border change: more defined or more irregular.
20
+ Surface change: re-epithelialization, new crusting, increased exudate, reduced scaling.
21
+ Overall healing trajectory: IMPROVING, STABLE, or DETERIORATING.
22
+ Any notable new findings since Day 1.
23
+
24
+ OUTPUT: plain clinical prose. No bullet points. No headers. No JSON. No diagnosis.
25
+ If image quality is poor, state "Image quality is limited; the following observations may be incomplete:" then proceed.
26
+ If no abnormality is visible, state "No visible cutaneous abnormality detected on the provided image."
27
+ Maximum 200 words per image."""
28
+
29
+
30
+ CLINICAL_AGENT_SYSTEM = """You are an experienced dermatology triage physician with wound-care expertise.
31
+ You receive: (1) an objective visual description from a vision specialist, and (2) the patient's own symptom report.
32
+ Perform clinical reasoning and output ONLY a single JSON object. No text before or after. No markdown fences.
33
+
34
+ Required schema:
35
+ {
36
+ "triage_level": "High" or "Medium" or "Low",
37
+ "urgency_reason": "one sentence in English explaining WHY this triage level was assigned",
38
+ "possible_conditions": [
39
+ {"name": "condition name in TARGET LANGUAGE", "probability": integer 5 to 95, "icd10": "X00.0"}
40
+ ],
41
+ "red_flags": ["specific alarming sign from visual or symptom data — English only"],
42
+ "watch_symptoms": ["symptom that should prompt immediate re-evaluation — English only"],
43
+ "clinical_assessment": "2-3 sentences in English explaining pathophysiology connection between findings and symptoms",
44
+ "recommendation": "2-4 sentence action plan in TARGET LANGUAGE, ranked by urgency"
45
+ }
46
 
47
+ TRIAGE RULES:
48
+ "High": suspected melanoma (asymmetry + irregular border + multiple colors + >6mm), necrotic tissue, deep ulceration,
49
+ rapidly spreading cellulitis (>2 cm/day), sepsis signs (fever + spreading erythema + systemic symptoms),
50
+ severe burn (full-thickness or >10% BSA), necrotizing fasciitis signs, exposed bone/tendon/joint,
51
+ bite wounds with high infection risk.
52
+ "Medium": localized infection signs (purulent exudate + erythema + warmth, contained),
53
+ non-healing wound >2 weeks, inflammatory lesion with moderate systemic symptoms,
54
+ suspected fungal infection needing prescription antifungal,
55
+ pigmented lesion with 1-2 atypical features, partial-thickness burn.
56
+ "Low": minor abrasion or superficial laceration with clean wound bed,
57
+ mild inflammatory rash without infection signs,
58
+ stable dry scaling lesion (likely eczema or psoriasis),
59
+ insect bite without secondary infection.
60
 
61
+ PROBABILITY RULES:
62
+ List 1-4 conditions maximum, ranked highest first. Probabilities may sum to more than 100 (conditions can co-exist).
63
+ Never assign 0% or 100%. Minimum 5%, maximum 95%.
64
+ Include the most dangerous condition on the differential even at low probability if visual evidence supports it.
65
 
66
+ RED FLAGS: only include if actual evidence exists in description or symptoms. Empty array if none.
67
+ Each flag must reference a specific observable finding, not a generic statement.
 
68
 
69
+ LANGUAGE: condition names and recommendation in TARGET LANGUAGE. All other fields in English.
 
 
 
 
 
 
70
 
71
+ Return ONLY the JSON object."""
 
 
 
72
 
 
 
73
 
74
+ CHAT_AGENT_SYSTEM = """You are a compassionate medical assistant continuing a consultation.
75
+ You have access to a completed dermatology and wound-care analysis. Answer the patient's follow-up question.
76
 
77
  RULES:
78
+ Answer entirely in TARGET LANGUAGE.
79
+ Be concise (2-4 sentences), empathetic, and specific to what the analysis found.
80
+ Reference specific findings from the context (e.g., "the redness we identified...").
81
+ Never name a specific prescription drug — say "your doctor may prescribe medication".
82
+ Never give a definitive diagnosis — say "the analysis suggests" or "signs are consistent with".
83
+ If the question is outside dermatology or wound care scope, say so and recommend the appropriate specialist.
84
+ Always close with a reminder to consult a doctor if symptoms change or worsen."""
85
+
86
+
87
+ PATIENT_AGENT_SYSTEM = """You are a medical communication specialist translating clinical findings into clear patient language.
88
+ Write ONLY the patient message. No headings, no labels, no separators, no bullet points.
89
+ Language: write entirely in TARGET LANGUAGE specified in the input.
90
+
91
+ Required structure — flowing prose, minimum 6 sentences:
92
+
93
+ Sentence 1 (empathetic opening): acknowledge the patient's concern by referencing their specific complaint.
94
+
95
+ Sentence 2 (what we observed): plain-language description of the key visual finding.
96
+ Skip this sentence entirely if VISUAL DESCRIPTION begins with "(No image provided".
97
+
98
+ Sentence 3 (what this might mean): explain the most likely condition in everyday language, no jargon.
99
+ If multiple conditions: "the most likely explanation is X; however, Y is also possible".
100
+
101
+ Sentence 4 (warning signs): if red_flags or watch_symptoms are present in the CLINICAL JSON, name them in plain language.
102
+ Phrase as: "you should seek immediate care if you notice [specific signs]".
103
+ Skip this sentence entirely if red_flags array is empty.
104
+
105
+ Sentence 5 (what to do now): specific action steps matching triage_level.
106
+ High triage: "please go to an emergency room or urgent care center today".
107
+ Medium triage: "schedule an appointment with a doctor within 1-3 days".
108
+ Low triage: "you can monitor this at home, but see a doctor if it does not improve within [X] days".
109
+
110
+ Sentence 6 (closing): one reassuring line encouraging professional consultation.
111
+
112
+ TONE: warm and clear. Not alarming unless triage is High. Not dismissive for Low triage.
113
+ Do NOT copy clinical jargon from the SOAP note. Use everyday language throughout.
114
+ Output only the message text. Nothing else."""
115
 
 
116
 
117
+ SOAP_AGENT_SYSTEM = """You are a clinical documentation specialist. Write a structured SOAP note for a dermatology and wound-care encounter.
118
+ Write ONLY the SOAP note in professional clinical English. No preamble, no commentary, no markdown.
119
 
120
+ Use exactly these four labeled sections:
 
 
 
 
 
121
 
122
+ S (Subjective):
123
+ Chief complaint and symptom narrative paraphrased in clinical English.
124
+ Include: duration, location, character of the complaint, aggravating or relieving factors, associated symptoms.
125
+ Translate to English if the original complaint was in another language.
126
 
127
+ O (Objective):
128
+ Visual examination findings.
129
+ If image provided: describe lesion morphology, estimated size, distribution, wound bed status, signs of infection.
130
+ If no image: write "No physical examination image provided."
131
 
132
+ A (Assessment):
133
+ Primary impression: most likely diagnosis with brief rationale referencing the O findings.
134
+ Differential diagnoses: 2-3 alternatives each with one distinguishing clinical feature.
135
+ Triage acuity: state level (High, Medium, or Low) and the urgency reason.
136
+ Red flags: list specific alarming findings, or write "No red flags identified."
137
 
138
+ P (Plan):
139
+ Rank recommendations by priority:
140
+ 1. Immediate actions if High triage or red flags are present.
141
+ 2. Diagnostic workup recommended (skin biopsy, culture, dermoscopy, etc. if indicated).
142
+ 3. Treatment approach category — wound care protocol or topical/systemic therapy without specific drug names.
143
+ 4. Follow-up timeline and specific return precautions.
144
+ 5. Patient education points.
145
 
146
+ Output only the four labeled sections. Nothing before S, nothing after the last P line."""