mou11 commited on
Commit
084c003
Β·
verified Β·
1 Parent(s): 104ab35

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +384 -0
app.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import re
5
+ import uuid
6
+ from datetime import datetime
7
+
8
+ from groq import Groq
9
+ from transformers import pipeline
10
+ from bert_score import score as bert_score
11
+ from rouge_score import rouge_scorer
12
+ from reportlab.lib.pagesizes import A4
13
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
14
+ from reportlab.lib.units import inch
15
+ from reportlab.lib import colors
16
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, HRFlowable
17
+ from reportlab.lib.enums import TA_CENTER
18
+ import gradio as gr
19
+
20
+
21
+ # ── Groq Client ──────────────────────────────────────────────
22
+ class GroqClientManager:
23
+ def __init__(self):
24
+ self.keys = []
25
+ self.current_index = 0
26
+ self._load_keys()
27
+
28
+ def _load_keys(self):
29
+ for i in range(1, 6):
30
+ key = os.environ.get(f"GROQ_KEY_{i}")
31
+ if key:
32
+ self.keys.append(key)
33
+ if not self.keys:
34
+ raise ValueError("No Groq API keys found. Add GROQ_KEY_1 to Space secrets.")
35
+ print(f"Loaded {len(self.keys)} Groq key(s)")
36
+
37
+ def get_client(self):
38
+ return Groq(api_key=self.keys[self.current_index])
39
+
40
+ def rotate(self):
41
+ self.current_index = (self.current_index + 1) % len(self.keys)
42
+
43
+ def chat(self, messages, max_tokens=1500, temperature=0.3):
44
+ for _ in range(len(self.keys)):
45
+ try:
46
+ client = self.get_client()
47
+ response = client.chat.completions.create(
48
+ model="llama-3.3-70b-versatile",
49
+ messages=messages,
50
+ max_tokens=max_tokens,
51
+ temperature=temperature
52
+ )
53
+ return response.choices[0].message.content
54
+ except Exception as e:
55
+ if "429" in str(e) or "rate_limit" in str(e).lower():
56
+ self.rotate()
57
+ time.sleep(2)
58
+ else:
59
+ raise e
60
+ raise RuntimeError("All Groq API keys exhausted.")
61
+
62
+
63
+ groq_manager = GroqClientManager()
64
+
65
+ print("Loading NLI model...")
66
+ nli_pipeline = pipeline(
67
+ "text-classification",
68
+ model="cross-encoder/nli-deberta-v3-base"
69
+ )
70
+ print("NLI model loaded.")
71
+
72
+
73
+ # ── Report Prompts ────────────────────────────────────────────
74
+ REPORT_PROMPTS = {
75
+ "radiology": """You are a radiologist writing a formal radiology report.
76
+ Given the following patient data, generate a structured radiology report.
77
+
78
+ Patient Data:
79
+ {patient_data}
80
+
81
+ Generate a report with these exact sections:
82
+ CLINICAL INDICATION:
83
+ TECHNIQUE:
84
+ FINDINGS:
85
+ IMPRESSION:
86
+
87
+ Be specific, clinical, and only use information provided. Do not invent findings.""",
88
+
89
+ "discharge": """You are a hospital physician writing a discharge summary.
90
+ Given the following patient data, generate a structured discharge summary.
91
+
92
+ Patient Data:
93
+ {patient_data}
94
+
95
+ Generate a report with these exact sections:
96
+ PATIENT INFORMATION:
97
+ ADMISSION DIAGNOSIS:
98
+ HOSPITAL COURSE:
99
+ DISCHARGE DIAGNOSIS:
100
+ DISCHARGE MEDICATIONS:
101
+ FOLLOW-UP INSTRUCTIONS:
102
+
103
+ Only use information provided. Do not invent medications or diagnoses.""",
104
+
105
+ "lab": """You are a clinical pathologist writing a laboratory report.
106
+ Given the following patient data, generate a structured lab report.
107
+
108
+ Patient Data:
109
+ {patient_data}
110
+
111
+ Generate a report with these exact sections:
112
+ TEST ORDERED:
113
+ SPECIMEN:
114
+ RESULTS:
115
+ REFERENCE RANGES:
116
+ INTERPRETATION:
117
+ RECOMMENDATION:
118
+
119
+ Only use information provided. Do not invent lab values."""
120
+ }
121
+
122
+
123
+ # ── Core Functions ────────────────────────────────────────────
124
+ def generate_report(patient_data, report_type):
125
+ patient_data_str = "\n".join([f"{k}: {v}" for k, v in patient_data.items()])
126
+ prompt = REPORT_PROMPTS[report_type].format(patient_data=patient_data_str)
127
+ response = groq_manager.chat([{"role": "user", "content": prompt}])
128
+ return {
129
+ "report_type": report_type,
130
+ "report_text": response.strip(),
131
+ "patient_data": patient_data,
132
+ "generated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
133
+ }
134
+
135
+
136
+ def extract_sentences(text):
137
+ sentences = re.split(r'(?<=[.!?])\s+', text.strip())
138
+ return [s.strip() for s in sentences
139
+ if len(s.strip()) > 20 and not s.strip().isupper() and ":" not in s[:25]]
140
+
141
+
142
+ def check_hallucination(report):
143
+ source_text = " ".join([f"{k} is {v}." for k, v in report["patient_data"].items()])
144
+ sentences = extract_sentences(report["report_text"])
145
+ if not sentences:
146
+ return {"error": "No checkable sentences found."}
147
+
148
+ results = []
149
+ hallucination_count = 0
150
+
151
+ for sentence in sentences:
152
+ nli_input = f"{source_text} [SEP] {sentence}"
153
+ prediction = nli_pipeline(nli_input, truncation=True, max_length=512)
154
+ label = prediction[0]["label"].lower()
155
+ score = prediction[0]["score"]
156
+
157
+ if "entail" in label:
158
+ status = "supported"
159
+ elif "contradict" in label:
160
+ status = "hallucinated"
161
+ hallucination_count += 1
162
+ else:
163
+ status = "unverified"
164
+ if score > 0.80:
165
+ hallucination_count += 0.5
166
+
167
+ results.append({"sentence": sentence, "status": status, "confidence": round(score, 4)})
168
+
169
+ total = len(sentences)
170
+ hallucination_rate = round(hallucination_count / total, 4) if total > 0 else 0
171
+
172
+ return {
173
+ "report_type": report["report_type"],
174
+ "total_claims": total,
175
+ "hallucination_rate": hallucination_rate,
176
+ "safety_score": round(1 - hallucination_rate, 4),
177
+ "claim_results": results
178
+ }
179
+
180
+
181
+ def evaluate_report(report):
182
+ reference = " ".join([f"{k} is {v}." for k, v in report["patient_data"].items()])
183
+ hypothesis = report["report_text"]
184
+
185
+ P, R, F1 = bert_score([hypothesis], [reference], lang="en", verbose=False)
186
+ scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
187
+ rouge_scores = scorer.score(reference, hypothesis)
188
+
189
+ return {
190
+ "bertscore": {
191
+ "precision": round(P[0].item(), 4),
192
+ "recall": round(R[0].item(), 4),
193
+ "f1": round(F1[0].item(), 4)
194
+ },
195
+ "rouge": {
196
+ "rouge1": round(rouge_scores["rouge1"].fmeasure, 4),
197
+ "rouge2": round(rouge_scores["rouge2"].fmeasure, 4),
198
+ "rougeL": round(rouge_scores["rougeL"].fmeasure, 4)
199
+ }
200
+ }
201
+
202
+
203
+ def create_fhir_report(report, hallucination_result):
204
+ report_type_codes = {
205
+ "radiology": {"code": "18748-4", "display": "Diagnostic imaging study"},
206
+ "discharge": {"code": "18842-5", "display": "Discharge summary"},
207
+ "lab": {"code": "11502-2", "display": "Laboratory report"}
208
+ }
209
+ code_info = report_type_codes.get(report["report_type"], {"code": "unknown", "display": "Clinical Report"})
210
+
211
+ return {
212
+ "resourceType": "DiagnosticReport",
213
+ "id": str(uuid.uuid4()),
214
+ "status": "final",
215
+ "category": [{"coding": [{"system": "http://terminology.hl7.org/CodeSystem/v2-0074",
216
+ "code": code_info["code"],
217
+ "display": code_info["display"]}]}],
218
+ "code": {"text": f"{report['report_type'].capitalize()} Report"},
219
+ "subject": {"reference": f"Patient/{str(uuid.uuid4())}",
220
+ "display": report["patient_data"].get("name", "Unknown")},
221
+ "effectiveDateTime": report["generated_at"],
222
+ "issued": datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"),
223
+ "conclusion": report["report_text"],
224
+ "extension": [
225
+ {"url": "https://medical-ai-portfolio.dev/fhir/hallucination-score",
226
+ "valueDecimal": hallucination_result.get("hallucination_rate", 0)},
227
+ {"url": "https://medical-ai-portfolio.dev/fhir/safety-score",
228
+ "valueDecimal": hallucination_result.get("safety_score", 1)},
229
+ {"url": "https://medical-ai-portfolio.dev/fhir/total-claims-checked",
230
+ "valueInteger": hallucination_result.get("total_claims", 0)}
231
+ ]
232
+ }
233
+
234
+
235
+ def export_pdf(report, hallucination_result, eval_result):
236
+ output_path = f"/tmp/{report['report_type']}_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"
237
+ doc = SimpleDocTemplate(output_path, pagesize=A4,
238
+ rightMargin=0.75*inch, leftMargin=0.75*inch,
239
+ topMargin=0.75*inch, bottomMargin=0.75*inch)
240
+ styles = getSampleStyleSheet()
241
+
242
+ title_style = ParagraphStyle("title", parent=styles["Title"], fontSize=16, spaceAfter=6, alignment=TA_CENTER)
243
+ subtitle_style= ParagraphStyle("subtitle",parent=styles["Normal"], fontSize=9, spaceAfter=12, alignment=TA_CENTER, textColor=colors.grey)
244
+ section_style = ParagraphStyle("section", parent=styles["Heading2"],fontSize=11, spaceBefore=12,spaceAfter=4, textColor=colors.HexColor("#1a1a2e"))
245
+ body_style = ParagraphStyle("body", parent=styles["Normal"], fontSize=9, spaceAfter=6, leading=14)
246
+ label_style = ParagraphStyle("label", parent=styles["Normal"], fontSize=9, textColor=colors.grey)
247
+
248
+ story = []
249
+ story.append(Paragraph("Medical Report Generator", title_style))
250
+ story.append(Paragraph(f"Medical AI Portfolio β€” Project 4 | Generated: {report['generated_at']}", subtitle_style))
251
+ story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor("#1a1a2e")))
252
+ story.append(Spacer(1, 12))
253
+ story.append(Paragraph(f"{report['report_type'].upper()} REPORT", section_style))
254
+ story.append(Spacer(1, 6))
255
+
256
+ patient_table_data = [[Paragraph(f"<b>{k.replace('_',' ').title()}</b>", label_style),
257
+ Paragraph(str(v), body_style)]
258
+ for k, v in report["patient_data"].items()]
259
+ patient_table = Table(patient_table_data, colWidths=[1.8*inch, 4.5*inch])
260
+ patient_table.setStyle(TableStyle([
261
+ ("BACKGROUND", (0,0),(0,-1), colors.HexColor("#f0f0f0")),
262
+ ("GRID", (0,0),(-1,-1),0.5, colors.lightgrey),
263
+ ("VALIGN", (0,0),(-1,-1),"TOP"),
264
+ ("PADDING", (0,0),(-1,-1),6),
265
+ ]))
266
+ story.append(Paragraph("Patient Information", section_style))
267
+ story.append(patient_table)
268
+ story.append(Spacer(1, 12))
269
+
270
+ story.append(Paragraph("Generated Report", section_style))
271
+ story.append(HRFlowable(width="100%", thickness=0.5, color=colors.lightgrey))
272
+ story.append(Spacer(1, 6))
273
+ for line in report["report_text"].split("\n"):
274
+ if line.strip():
275
+ if line.strip().isupper() or line.strip().endswith(":"):
276
+ story.append(Paragraph(f"<b>{line.strip()}</b>", body_style))
277
+ else:
278
+ story.append(Paragraph(line.strip(), body_style))
279
+ story.append(Spacer(1, 12))
280
+
281
+ story.append(Paragraph("Quality Assessment", section_style))
282
+ quality_data = [
283
+ ["Metric", "Value"],
284
+ ["Total Claims Checked", str(hallucination_result.get("total_claims", 0))],
285
+ ["Hallucination Rate", str(hallucination_result.get("hallucination_rate", 0))],
286
+ ["Safety Score", str(hallucination_result.get("safety_score", 0))],
287
+ ["BERTScore F1", str(eval_result["bertscore"]["f1"])],
288
+ ["ROUGE-1", str(eval_result["rouge"]["rouge1"])],
289
+ ["ROUGE-2", str(eval_result["rouge"]["rouge2"])],
290
+ ["ROUGE-L", str(eval_result["rouge"]["rougeL"])],
291
+ ]
292
+ quality_table = Table(quality_data, colWidths=[2.5*inch, 2*inch])
293
+ quality_table.setStyle(TableStyle([
294
+ ("BACKGROUND", (0,0),(-1,0), colors.HexColor("#1a1a2e")),
295
+ ("TEXTCOLOR", (0,0),(-1,0), colors.white),
296
+ ("FONTNAME", (0,0),(-1,0), "Helvetica-Bold"),
297
+ ("GRID", (0,0),(-1,-1),0.5, colors.lightgrey),
298
+ ("BACKGROUND", (0,1),(-1,-1),colors.HexColor("#f9f9f9")),
299
+ ("PADDING", (0,0),(-1,-1),6),
300
+ ]))
301
+ story.append(quality_table)
302
+ story.append(Spacer(1, 12))
303
+ story.append(HRFlowable(width="100%", thickness=0.5, color=colors.lightgrey))
304
+ story.append(Paragraph("Generated by Medical Report Generator | Moumita Roy | Medical AI Portfolio Project 4", subtitle_style))
305
+ doc.build(story)
306
+ return output_path
307
+
308
+
309
+ # ── Pipeline ──────────────────────────────────────────────────
310
+ def run_pipeline(name, age, sex, chief_complaint, vitals, history, imaging, labs, report_type):
311
+ patient_data = {
312
+ "name": name, "age": age, "sex": sex,
313
+ "chief_complaint": chief_complaint, "vitals": vitals,
314
+ "history": history, "imaging": imaging, "labs": labs
315
+ }
316
+ patient_data = {k: v for k, v in patient_data.items() if v and str(v).strip()}
317
+
318
+ report = generate_report(patient_data, report_type)
319
+ h_result = check_hallucination(report)
320
+ e_result = evaluate_report(report)
321
+ fhir = create_fhir_report(report, h_result)
322
+ pdf_path = export_pdf(report, h_result, e_result)
323
+
324
+ claim_breakdown = "\n".join([
325
+ f"[{c['status'].upper():12}] ({c['confidence']}) {c['sentence'][:90]}..."
326
+ for c in h_result.get("claim_results", [])
327
+ ])
328
+
329
+ hallucination_summary = f"""Total Claims Checked : {h_result['total_claims']}
330
+ Hallucination Rate : {h_result['hallucination_rate']}
331
+ Safety Score : {h_result['safety_score']}
332
+
333
+ Claim-level Breakdown:
334
+ {claim_breakdown}"""
335
+
336
+ metrics_summary = f"""BERTScore β€” P: {e_result['bertscore']['precision']} R: {e_result['bertscore']['recall']} F1: {e_result['bertscore']['f1']}
337
+ ROUGE-1 : {e_result['rouge']['rouge1']}
338
+ ROUGE-2 : {e_result['rouge']['rouge2']}
339
+ ROUGE-L : {e_result['rouge']['rougeL']}"""
340
+
341
+ return (
342
+ report["report_text"],
343
+ hallucination_summary,
344
+ metrics_summary,
345
+ json.dumps(fhir, indent=2),
346
+ pdf_path
347
+ )
348
+
349
+
350
+ # ── Gradio UI ─────────────────────────────────────────────────
351
+ with gr.Blocks(title="Medical Report Generator") as demo:
352
+ gr.Markdown("# Medical Report Generator")
353
+ gr.Markdown("Medical AI Portfolio β€” Project 4 | Moumita Roy | [GitHub](https://github.com/moumitaroy19/medical-report-generator)")
354
+
355
+ with gr.Row():
356
+ with gr.Column():
357
+ gr.Markdown("### Patient Information")
358
+ name = gr.Textbox(label="Full Name", value="John Doe")
359
+ age = gr.Textbox(label="Age", value="58")
360
+ sex = gr.Textbox(label="Sex", value="Male")
361
+ chief_complaint = gr.Textbox(label="Chief Complaint", value="Chest pain and shortness of breath for 2 days")
362
+ vitals = gr.Textbox(label="Vitals", value="BP 145/90, HR 88, RR 18, Temp 37.1C, SpO2 96%")
363
+ history = gr.Textbox(label="Medical History", value="Hypertension, Type 2 Diabetes, smoker for 20 years")
364
+ imaging = gr.Textbox(label="Imaging", value="Chest X-ray ordered, mild cardiomegaly noted")
365
+ labs = gr.Textbox(label="Lab Results", value="WBC 11.2, HGB 13.4, Troponin 0.02, BNP 210")
366
+ report_type = gr.Dropdown(choices=["radiology", "discharge", "lab"],
367
+ value="radiology", label="Report Type")
368
+ submit_btn = gr.Button("Generate Report", variant="primary")
369
+
370
+ with gr.Column():
371
+ gr.Markdown("### Output")
372
+ report_output = gr.Textbox(label="Generated Report", lines=12)
373
+ hallucination_output = gr.Textbox(label="Hallucination Analysis", lines=10)
374
+ metrics_output = gr.Textbox(label="Evaluation Metrics", lines=6)
375
+ fhir_output = gr.Textbox(label="FHIR R4 JSON", lines=10)
376
+ pdf_output = gr.File(label="Download PDF Report")
377
+
378
+ submit_btn.click(
379
+ fn=run_pipeline,
380
+ inputs=[name, age, sex, chief_complaint, vitals, history, imaging, labs, report_type],
381
+ outputs=[report_output, hallucination_output, metrics_output, fhir_output, pdf_output]
382
+ )
383
+
384
+ demo.launch()