pangweijlu's picture
Add synthetic data generation for images and CGM
d133742 verified
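"""
Diabetes Multi-Agent Pro — a multi-modal, multi-agent analysis framework
with synthetic data generation.

Architecture: a manager agent that coordinates four specialist agents
(retinal image, CGM time series, clinical Q&A text, tabular risk) plus a
synthetic data generator for fundus images, CGM traces and patient features.
"""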
import gradio as gr
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFilter
import io
import re
import urllib.request
import json
# Base URL of the Hugging Face datasets-server REST API; fetch_rows() appends
# the "/rows" endpoint and its query string to it.
DS_SERVER = "https://datasets-server.huggingface.co"
# ═══════════════════════════════════════════════════════════════
# SYNTHETIC DATA GENERATION
# ═══════════════════════════════════════════════════════════════
def generate_synthetic_fundus(severity="Healthy", seed=None):
    """Generate a synthetic 256x256 RGB retinal fundus image for a severity grade.

    The image has a dark background, an optic disc with a brighter centre,
    and eight curved vessel-like strokes radiating from the disc. For DR
    grades it adds random red blotches whose count grows with severity
    (3 / 8 / 15 / 25). "Severe DR" and "Proliferate DR" also get five short
    red line segments. The finished image is lightly Gaussian-blurred.

    Args:
        severity: One of "Healthy", "Mild DR", "Moderate DR", "Severe DR",
            "Proliferate DR". Any other value draws no lesions but uses the
            non-healthy vessel colour.
        seed: Optional integer seed that makes lesion placement reproducible.
            When None, NumPy's global random state is used, which matches the
            previous behaviour.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    # np.random (module) and RandomState expose the same randint() API.
    rng = np.random if seed is None else np.random.RandomState(seed)
    hemorrhage_counts = {"Mild DR": 3, "Moderate DR": 8, "Severe DR": 15, "Proliferate DR": 25}

    img = Image.new("RGB", (256, 256), color=(20, 10, 10))
    draw = ImageDraw.Draw(img)

    # Optic disc (outer) and its brighter centre.
    disc_x, disc_y = 80, 128
    draw.ellipse([disc_x - 25, disc_y - 25, disc_x + 25, disc_y + 25], fill=(180, 140, 100))
    draw.ellipse([disc_x - 10, disc_y - 10, disc_x + 10, disc_y + 10], fill=(220, 200, 160))

    # Eight slowly curving strokes that taper as they leave the disc.
    for angle in np.linspace(0, 2 * np.pi, 8, endpoint=False):
        x, y = disc_x, disc_y
        for step in range(30):
            x += np.cos(angle + step * 0.05) * 4
            y += np.sin(angle + step * 0.05) * 4
            width = max(1, 4 - step // 8)
            color = (120 + step * 2, 30, 30) if severity == "Healthy" else (140 + step, 40, 40)
            draw.ellipse([x - width, y - width, x + width, y + width], fill=color)

    # Red blotches; the count is 0 for Healthy or unrecognised grades.
    for _ in range(hemorrhage_counts.get(severity, 0)):
        hx, hy = rng.randint(30, 226), rng.randint(30, 226)
        size = rng.randint(2, 8)
        draw.ellipse([hx - size, hy - size, hx + size, hy + size], fill=(180, 20, 20))

    if severity in ("Severe DR", "Proliferate DR"):
        for _ in range(5):
            nx, ny = rng.randint(50, 200), rng.randint(50, 200)
            draw.line([(nx, ny), (nx + 20, ny + 10)], fill=(200, 50, 50), width=2)

    return img.filter(ImageFilter.GaussianBlur(radius=1))
def generate_synthetic_cgm(duration_hours=72, pattern="stable", seed=None):
    """Generate a synthetic CGM glucose trace in mg/dL.

    The trace is ``base + amplitude * (0.3*sin(24 h) + 0.2*sin(8 h) + 0.1*noise)``,
    clipped to [40, 400]. It has 12 samples per hour, i.e. one every 5 minutes.

    Args:
        duration_hours: Length of the trace in hours. Fractional values are
            rounded to the nearest whole sample; non-positive values yield [].
        pattern: "stable", "volatile", "hypo_prone" or "hyper_prone". Unknown
            names fall back to "stable".
        seed: Optional integer seed that makes the noise reproducible. When
            None, NumPy's global random state is used, which matches the
            previous behaviour.

    Returns:
        A list of floats, one per 5-minute sample.
    """
    # int() is required: np.random.randn rejects a float sample count.
    n_points = max(0, int(round(duration_hours * 12)))
    t = np.arange(n_points)
    patterns = {
        "stable": (120, 30),
        "volatile": (140, 60),
        # NOTE(review): the sine terms swing at most 0.5*amplitude, so this
        # pattern's deterministic minimum is 80 mg/dL. Readings < 70 need a
        # ≥2.5σ noise draw and are rare; consider a larger amplitude if
        # "hypo_prone" should reliably trigger hypoglycemia metrics.
        "hypo_prone": (100, 40),
        "hyper_prone": (160, 50),
    }
    base, amplitude = patterns.get(pattern, (120, 30))
    rng = np.random if seed is None else np.random.RandomState(seed)
    # At 12 samples/hour, 288 samples is a 24 h cycle and 96 samples an 8 h cycle.
    circadian = amplitude * (
        0.3 * np.sin(2 * np.pi * t / 288)
        + 0.2 * np.sin(2 * np.pi * t / 96)
        + 0.1 * rng.randn(n_points)
    )
    glucose = np.clip(base + circadian, 40, 400)
    return glucose.tolist()
def generate_synthetic_patient(profile="high_risk", seed=None):
    """Generate synthetic tabular patient features around a named risk profile.

    Each feature is the profile's base value plus Gaussian noise:
    σ=0.05 for "ped", σ=1.5 for "bmi", and σ=3 for every integer feature.
    Integer features are rounded to the nearest whole number. All values are
    clamped at 0, so counts such as pregnancies can never be negative.

    Args:
        profile: "low_risk", "moderate_risk", "high_risk" or "elderly".
            Unknown names fall back to "moderate_risk".
        seed: Optional integer seed that makes the draw reproducible. When
            None, NumPy's global random state is used.

    Returns:
        Dict with keys preg, gluc, bp, skin, ins, age (int) and bmi, ped (float).
    """
    profiles = {
        "low_risk": {"preg": 0, "gluc": 95, "bp": 68, "skin": 18, "ins": 45, "bmi": 22.5, "ped": 0.25, "age": 28},
        "moderate_risk": {"preg": 2, "gluc": 130, "bp": 72, "skin": 25, "ins": 85, "bmi": 28.0, "ped": 0.45, "age": 42},
        "high_risk": {"preg": 5, "gluc": 165, "bp": 82, "skin": 35, "ins": 120, "bmi": 36.5, "ped": 0.65, "age": 52},
        "elderly": {"preg": 0, "gluc": 148, "bp": 78, "skin": 28, "ins": 95, "bmi": 31.0, "ped": 0.55, "age": 68},
    }
    base = profiles.get(profile, profiles["moderate_risk"])
    rng = np.random if seed is None else np.random.RandomState(seed)

    patient = {}
    # Draw order follows the profile's key order so seeded output is stable.
    for key, value in base.items():
        if key == "ped":
            patient[key] = round(max(0.0, float(value + rng.normal(0, 0.05))), 2)
        elif key == "bmi":
            patient[key] = round(max(0.0, float(value + rng.normal(0, 1.5))), 1)
        else:
            # Round instead of int() truncation, which biased values toward 0.
            patient[key] = max(0, int(round(float(value + rng.normal(0, 3)))))
    return patient
# ═══════════════════════════════════════════════════════════════
# REAL DATA FETCHING
# ═══════════════════════════════════════════════════════════════
def fetch_rows(dataset, offset=0, limit=1, config="default", split="train", timeout=30):
    """Fetch rows of a Hugging Face dataset through the datasets-server ``/rows`` endpoint.

    Args:
        dataset: Dataset repo id, e.g. "user/name".
        offset: Index of the first row to return.
        limit: Maximum number of rows to return.
        config: Dataset configuration name (defaults to "default").
        split: Dataset split (defaults to "train").
        timeout: Socket timeout in seconds for the HTTP request.

    Returns:
        A list of row dicts: the "row" value of each entry in the response's
        "rows" list ({} where an entry has no "row").

    Raises:
        urllib.error.URLError / HTTPError on network or server errors, and
        json.JSONDecodeError on a malformed response.
    """
    from urllib.parse import urlencode

    # urlencode escapes the query values, unlike hand-built f-string queries.
    query = urlencode({
        "dataset": dataset,
        "config": config,
        "split": split,
        "offset": int(offset),
        "limit": int(limit),
    })
    req = urllib.request.Request(f"{DS_SERVER}/rows?{query}", headers={"User-Agent": "Mozilla/5.0"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        data = json.loads(resp.read().decode("utf-8"))
    return [r.get("row", {}) for r in data.get("rows", [])]
def fetch_image(row, timeout=30):
    """Return a PIL image for a dataset row's "image" field, or None.

    Supported encodings of ``row["image"]``:
      * dict with a non-empty "src" URL: the image is downloaded;
      * dict with raw "bytes" (Hugging Face datasets image encoding);
      * raw ``bytes`` / ``bytearray``.

    Anything else, including a missing or empty row, yields None.

    Args:
        row: Row dict as returned by fetch_rows(), or None.
        timeout: Socket timeout in seconds for the URL download.

    Raises:
        urllib.error.URLError on download failure; PIL errors for undecodable data.
    """
    if not row:
        return None
    img_data = row.get("image")
    if isinstance(img_data, dict):
        src = img_data.get("src")
        if src:
            req = urllib.request.Request(src, headers={"User-Agent": "Mozilla/5.0"})
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                return Image.open(io.BytesIO(resp.read()))
        raw = img_data.get("bytes")
        if isinstance(raw, (bytes, bytearray)):
            return Image.open(io.BytesIO(raw))
    elif isinstance(img_data, (bytes, bytearray)):
        return Image.open(io.BytesIO(img_data))
    return None
# ═══════════════════════════════════════════════════════════════
# SPECIALIST AGENTS
# ═══════════════════════════════════════════════════════════════
# Clinical follow-up text per DR severity label (shared by real and synthetic paths).
_DR_RECOMMENDATIONS = {
    "Healthy": "Continue annual screening.",
    "Mild DR": "Tighter glycemic control + 6-month follow-up.",
    "Moderate DR": "Ophthalmology referral + quarterly monitoring.",
    "Severe DR": "Urgent ophthalmology; pan-retinal photocoagulation may be needed.",
    "Proliferate DR": "Immediate laser therapy or anti-VEGF required.",
}


def _summarize_fundus(img, severity, source):
    """Build the image-agent result dict from a PIL image and its severity label."""
    arr = np.array(img)
    return {
        "status": "ok",
        "source": source,
        "severity": severity,
        "mean_intensity": float(np.mean(arr)),
        "std_intensity": float(np.std(arr)),
        "size": img.size,
        "recommendation": _DR_RECOMMENDATIONS.get(severity, "Manual expert review recommended."),
    }


def image_agent(idx, use_synthetic=False, synthetic_severity="Healthy"):
    """Analyse one retinal fundus image (synthetic or fetched from the HF Hub).

    Args:
        idx: Row index into the Rami/Diabetic_Retinopathy_Preprocessed_Dataset_256x256
            train split (ignored when ``use_synthetic`` is true).
        use_synthetic: Generate an image with generate_synthetic_fundus() instead.
        synthetic_severity: Severity grade passed to the synthetic generator.

    Returns:
        On success, a dict with status "ok", source ("real"/"synthetic"),
        severity, mean/std pixel intensity, image size and a recommendation.
        {"status": "partial", ...} if the row arrived but its image could not be
        decoded, and {"status": "error", "message": ...} on any failure.
        Exceptions are never raised to the caller.
    """
    try:
        if use_synthetic:
            img = generate_synthetic_fundus(synthetic_severity)
            return _summarize_fundus(img, synthetic_severity, "synthetic")

        rows = fetch_rows("Rami/Diabetic_Retinopathy_Preprocessed_Dataset_256x256", int(idx), 1)
        if not rows:
            return {"status": "error", "message": "No image data returned"}
        row = rows[0]
        # NOTE(review): the datasets-server may return ClassLabel columns as
        # integer ids rather than names. If so, this lookup falls through to
        # "Manual expert review recommended." — confirm the label format.
        label = row.get("label", "Unknown")
        img = fetch_image(row)
        if img is not None:
            return _summarize_fundus(img, label, "real")
        return {"status": "partial", "severity": label, "message": "Image fetch failed"}
    except Exception as e:
        # Agent boundary: report the failure to the manager instead of raising.
        return {"status": "error", "message": str(e)}
def _summarize_glucose(values):
    """Compute glycemic-control metrics for a sequence of glucose readings (mg/dL).

    Non-finite readings (NaN, inf, None coerced to NaN) are dropped first.
    Raises ValueError if no valid reading remains.
    """
    arr = np.asarray(values, dtype=float)
    arr = arr[np.isfinite(arr)]
    m = len(arr)
    if m == 0:
        raise ValueError("CGM series contains no valid readings")

    mean_g = float(np.mean(arr))
    std_g = float(np.std(arr))
    in_range = np.sum((arr >= 70) & (arr <= 180)) / m * 100
    hypo = np.sum(arr < 70) / m * 100
    hyper = np.sum(arr > 180) / m * 100
    cv = (std_g / mean_g * 100) if mean_g > 0 else 0.0

    # Compare equally sized first/last-quarter windows. Slice with -(m // 4),
    # not -m // 4: the latter floors the negated length, giving a larger tail
    # window, or the whole array when m < 4. Use at least one sample.
    window = max(1, m // 4)
    q1 = np.mean(arr[:window])
    q4 = np.mean(arr[-window:])
    trend = "Rising" if q4 > q1 * 1.05 else "Falling" if q4 < q1 * 0.95 else "Stable"

    note = (
        "Significant hyperglycemia burden." if hyper > 30 else
        "Notable hypoglycemia episodes." if hypo > 10 else
        "Good glycemic control." if in_range > 70 and cv < 25 else
        "Mixed control. Optimize lifestyle/medication."
    )
    return {
        "readings": m,
        "mean": round(mean_g, 1),
        "std": round(std_g, 1),
        "cv": round(cv, 1),
        "time_in_range": round(in_range, 1),
        "hypo_pct": round(hypo, 1),
        "hyper_pct": round(hyper, 1),
        "trend": trend,
        "assessment": note,
    }


def cgm_agent(idx, use_synthetic=False, synthetic_pattern="stable"):
    """Summarise one CGM glucose trace (synthetic, or from elizah521/cgm_glucose_dataset).

    Args:
        idx: Row index into the dataset's train split (ignored when synthetic).
        use_synthetic: Use a 72-hour generate_synthetic_cgm() trace instead.
        synthetic_pattern: Pattern name passed to the synthetic generator.

    Returns:
        On success, a dict with status "ok", source, patient_id, the number of
        valid readings, mean/std/CV, time-in-range (70-180), hypo (<70) and
        hyper (>180) percentages, a first-vs-last-quarter trend and a
        one-line assessment. On any failure, {"status": "error", "message": ...}.
        Exceptions are never raised to the caller.
    """
    try:
        if use_synthetic:
            values = generate_synthetic_cgm(72, synthetic_pattern)
            meta = {
                "status": "ok",
                "source": "synthetic",
                "patient_id": f"SYNTH_{synthetic_pattern.upper()}",
            }
        else:
            from datasets import load_dataset
            ds = load_dataset("elizah521/cgm_glucose_dataset", split="train")
            idx = int(idx)
            if not (0 <= idx < len(ds)):
                return {"status": "error", "message": f"Index {idx} out of range"}
            row = ds[idx]
            values = row["target"]
            meta = {"status": "ok", "source": "real", "patient_id": row["item_id"]}
        return {**meta, **_summarize_glucose(values)}
    except Exception as e:
        # Agent boundary: report the failure to the manager instead of raising.
        return {"status": "error", "message": str(e)}
# Word tokenizer used for keyword-overlap retrieval (compiled once).
_WORD_RE = re.compile(r"\b\w+\b")


def text_agent(question, top_k=3):
    """Retrieve Q&A pairs from abdelhakimDZ/diabetes_QA_dataset by keyword overlap.

    Each dataset question is scored by the number of distinct lower-cased
    words it shares with ``question``. Only rows with a positive score are
    kept, so unrelated rows are never reported as matches.

    Args:
        question: Free-text clinical question (None is treated as "").
        top_k: Maximum number of matches to return.

    Returns:
        {"status": "ok", "matches": [{"question", "answer"}, ...]} best-first,
        with ties in dataset order. {"status": "partial", "matches": [],
        "general_advice": ...} when nothing overlaps, and
        {"status": "error", "message": ...} on failure.
    """
    try:
        from datasets import load_dataset
        ds = load_dataset("abdelhakimDZ/diabetes_QA_dataset", split="train")
        # NOTE(review): common words ("what", "i", "with") count as overlap;
        # a stop-word filter would likely improve ranking.
        keywords = set(_WORD_RE.findall((question or "").lower()))
        scored = []
        for i, row in enumerate(ds):
            row_words = _WORD_RE.findall((row["question"] or "").lower())
            score = len(keywords.intersection(row_words))
            if score > 0:
                scored.append((score, i, row))
        # A stable sort keeps dataset order among equal scores.
        scored.sort(key=lambda item: item[0], reverse=True)
        top = scored[:top_k]
        if not top:
            return {
                "status": "partial",
                "matches": [],
                "general_advice": (
                    "General diabetes guidance:\n"
                    "β€’ Maintain regular glucose monitoring.\n"
                    "β€’ Balanced diet low in refined sugars, high in fiber.\n"
                    "β€’ β‰₯150 min/week moderate aerobic activity.\n"
                    "β€’ Target HbA1c <7% (individualized).\n"
                    "β€’ Annual eye, kidney, and foot exams."
                )
            }
        return {
            "status": "ok",
            "matches": [{"question": row["question"], "answer": row["answer"]} for _, _, row in top],
        }
    except Exception as e:
        # Agent boundary: report the failure to the manager instead of raising.
        return {"status": "error", "message": str(e)}
def risk_agent(preg, gluc, bp, skin, ins, bmi, ped, age):
    """Score diabetes risk from tabular features with a simple 5-point rule set.

    One point is awarded for each of: glucose >= 140, BMI >= 30, age >= 45,
    pregnancies >= 6, and diabetes pedigree >= 0.5. Scores map to levels:
    0-1 is Low, 2 is Moderate, 3+ is High. ``bp``, ``skin`` and ``ins`` are
    accepted for interface compatibility but are not scored.

    Missing values (None, e.g. an empty Gradio number field) are skipped and
    listed under "missing" instead of crashing the caller.

    Returns:
        {"status": "ok", "score", "level", "factors", "recommendation",
        "missing"}, or {"status": "error", "message": ...} if an input cannot
        be converted to a number.
    """
    rules = (
        ("gluc", gluc, 140, "Elevated OGTT glucose"),
        ("bmi", bmi, 30, "Obesity (BMIβ‰₯30)"),
        ("age", age, 45, "Age β‰₯45"),
        ("preg", preg, 6, "High parity"),
        ("ped", ped, 0.5, "Strong family history"),
    )
    factors = []
    missing = []
    try:
        for name, value, threshold, label in rules:
            if value is None:
                missing.append(name)
            # float() matches the former int() truncation for these thresholds
            # but also accepts strings such as "140.5".
            elif float(value) >= threshold:
                factors.append(label)
    except (TypeError, ValueError) as e:
        return {"status": "error", "message": str(e)}

    risk = len(factors)
    level = "Low" if risk <= 1 else "Moderate" if risk == 2 else "High"
    rec = {
        "Low": "Routine screening.",
        "Moderate": "Lifestyle intervention + annual testing.",
        "High": "Urgent diabetes screening + specialist referral.",
    }[level]
    return {
        "status": "ok",
        "score": risk,
        "level": level,
        "factors": factors,
        "recommendation": rec,
        "missing": missing,
    }
# ═══════════════════════════════════════════════════════════════
# MANAGER AGENT
# ═══════════════════════════════════════════════════════════════
def manager_agent(image_idx, cgm_idx, question, preg, gluc, bp, skin, ins, bmi, ped, age,
                  use_synthetic_img, img_severity, use_synthetic_cgm, cgm_pattern):
    """Run the four specialist agents and assemble a plain-text clinical report.

    Calls image_agent, cgm_agent, text_agent and risk_agent in turn. It writes
    one section per agent (or that agent's error message) and then an
    integrated synthesis that combines the risk level, the CGM assessment and
    the retinal severity.

    Args:
        image_idx, use_synthetic_img, img_severity: Forwarded to image_agent.
        cgm_idx, use_synthetic_cgm, cgm_pattern: Forwarded to cgm_agent.
        question: Forwarded to text_agent.
        preg, gluc, bp, skin, ins, bmi, ped, age: Forwarded to risk_agent.

    Returns:
        The report as a single newline-joined string.
    """
    img_result = image_agent(image_idx, use_synthetic_img, img_severity)
    cgm_result = cgm_agent(cgm_idx, use_synthetic_cgm, cgm_pattern)
    txt_result = text_agent(question)
    risk_result = risk_agent(preg, gluc, bp, skin, ins, bmi, ped, age)

    report = []
    report.append("=" * 60)
    report.append("🩺 DIABETES MULTI-MODAL ANALYSIS REPORT")
    report.append("=" * 60)

    report.append("\nπŸ“Έ AGENT 1: RETINAL IMAGE ANALYSIS")
    report.append("-" * 40)
    if img_result.get("status") == "ok":
        report.append(f"Source: {img_result.get('source', 'unknown').upper()}")
        report.append(f"Severity: {img_result['severity']}")
        report.append(f"Image Stats: Mean={img_result['mean_intensity']:.1f}, Std={img_result['std_intensity']:.1f}")
        report.append(f"Recommendation: {img_result['recommendation']}")
    else:
        report.append(f"Status: {img_result.get('message', 'Unknown error')}")

    report.append("\nπŸ“ˆ AGENT 2: CGM TIME-SERIES ANALYSIS")
    report.append("-" * 40)
    if cgm_result.get("status") == "ok":
        report.append(f"Source: {cgm_result.get('source', 'unknown').upper()}")
        report.append(f"Patient: {cgm_result['patient_id']} | Readings: {cgm_result['readings']}")
        report.append(f"Glucose: Mean={cgm_result['mean']} mg/dL, Std={cgm_result['std']}, CV={cgm_result['cv']}%")
        report.append(f"Time in Range: {cgm_result['time_in_range']}%")
        report.append(f"Hypo (<70): {cgm_result['hypo_pct']}% | Hyper (>180): {cgm_result['hyper_pct']}%")
        report.append(f"Trend: {cgm_result['trend']}")
        report.append(f"Assessment: {cgm_result['assessment']}")
    else:
        report.append(f"Status: {cgm_result.get('message', 'Unknown error')}")

    report.append("\nπŸ’¬ AGENT 3: CLINICAL KNOWLEDGE RETRIEVAL")
    report.append("-" * 40)
    txt_status = txt_result.get("status")
    if txt_status == "ok":
        for i, match in enumerate(txt_result["matches"], 1):
            report.append(f"Match #{i}:")
            report.append(f"Q: {match['question']}")
            report.append(f"A: {match['answer']}")
            report.append("")
    elif txt_status == "error":
        # Surface the failure instead of pretending no match was found.
        report.append(f"Status: {txt_result.get('message', 'Unknown error')}")
    else:
        report.append(txt_result.get("general_advice", "No matches found"))

    report.append("\nπŸ“‹ AGENT 4: TABULAR RISK ASSESSMENT")
    report.append("-" * 40)
    risk_ok = risk_result.get("status") == "ok"
    if risk_ok:
        report.append(f"Risk Score: {risk_result['score']}/5.0 β†’ {risk_result['level']} Risk")
        report.append(f"Risk Factors: {', '.join(risk_result['factors']) or 'None'}")
        report.append(f"Recommendation: {risk_result['recommendation']}")
    else:
        report.append(f"Status: {risk_result.get('message', 'Unknown error')}")

    report.append("\n" + "=" * 60)
    report.append("πŸ”¬ INTEGRATED CLINICAL SYNTHESIS")
    report.append("=" * 60)
    risk_level = risk_result.get("level", "Unknown") if risk_ok else "Unknown"
    cgm_status = cgm_result.get("assessment", "Unknown") if cgm_result.get("status") == "ok" else "Unknown"
    # Separate name: reusing `img_severity` would shadow the requested synthetic grade.
    observed_severity = img_result.get("severity", "Unknown") if img_result.get("status") == "ok" else "Unknown"

    if risk_level == "High":
        report.append("⚠️ URGENT: Multiple risk factors present. Endocrinology referral recommended.")
    elif risk_level == "Moderate":
        report.append("⚑ MODERATE RISK: Lifestyle intervention + 6-month follow-up advised.")
    elif risk_level == "Low":
        report.append("βœ… LOW RISK: Continue routine screening.")
    else:
        # Never report "low risk" when the risk agent failed.
        report.append("⚠️ RISK UNAVAILABLE: Tabular risk could not be computed. Verify patient inputs.")

    cgm_lower = cgm_status.lower()
    if "hyperglycemia" in cgm_lower or "mixed" in cgm_lower:
        report.append("🩸 CGM ALERT: Suboptimal glucose control. Review insulin/medication regimen.")
    elif "hypoglycemia" in cgm_lower:
        report.append("🩸 CGM ALERT: Hypoglycemia episodes detected. Review insulin/medication regimen.")

    if observed_severity not in ["Healthy", "Unknown"]:
        report.append("πŸ‘οΈ RETINOPATHY ALERT: Abnormal fundus findings. Schedule ophthalmology evaluation.")

    report.append("\n" + "=" * 60)
    return "\n".join(report)
# ═══════════════════════════════════════════════════════════════
# GRADIO UI
# ═══════════════════════════════════════════════════════════════
with gr.Blocks(title="Diabetes Multi-Agent Pro") as demo:
    gr.Markdown("""
# πŸ€– Diabetes Multi-Agent Pro

**Architecture**: Manager Agent + 4 Specialist Agents + Synthetic Data Generator

| Agent | Modality | Data Source |
|---|---|---|
| **Agent 1** | Retinal Images | Real (HF Hub) or Synthetic |
| **Agent 2** | CGM Time-Series | Real (HF Hub) or Synthetic |
| **Agent 3** | Clinical Text | Real Q&A (HF Hub) |
| **Agent 4** | Structured Risk | Manual entry or Synthetic |
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### πŸŽ›οΈ Data Source Selection")
            use_synthetic_img = gr.Checkbox(label="Use Synthetic Fundus Image", value=False)
            img_severity = gr.Dropdown(
                choices=["Healthy", "Mild DR", "Moderate DR", "Severe DR", "Proliferate DR"],
                value="Healthy",
                label="Synthetic Image Severity"
            )
            use_synthetic_cgm = gr.Checkbox(label="Use Synthetic CGM Data", value=False)
            cgm_pattern = gr.Dropdown(
                choices=["stable", "volatile", "hypo_prone", "hyper_prone"],
                value="stable",
                label="Synthetic CGM Pattern"
            )

            gr.Markdown("### πŸ“Š Patient Data Input")
            image_idx = gr.Number(value=5, label="Fundus Image Index (0-2749)", precision=0)
            cgm_idx = gr.Number(value=3, label="CGM Index (0-197)", precision=0)
            question = gr.Textbox(value="What should I eat with high BMI?", label="Clinical Question")
            with gr.Row():
                preg = gr.Number(value=2, label="Pregnancies", precision=0)
                gluc = gr.Number(value=155, label="Glucose", precision=0)
                bp = gr.Number(value=76, label="BP", precision=0)
                skin = gr.Number(value=32, label="Skin Thickness", precision=0)
            with gr.Row():
                ins = gr.Number(value=95, label="Insulin", precision=0)
                bmi = gr.Number(value=35.4, label="BMI")
                ped = gr.Number(value=0.55, label="Pedigree")
                age = gr.Number(value=38, label="Age", precision=0)
            # Lets generate_synthetic_patient() populate the tabular inputs.
            with gr.Row():
                patient_profile = gr.Dropdown(
                    choices=["low_risk", "moderate_risk", "high_risk", "elderly"],
                    value="high_risk",
                    label="Synthetic Patient Profile",
                )
                fill_patient_btn = gr.Button("Fill With Synthetic Patient")

        with gr.Column(scale=1):
            gr.Markdown("### πŸ“‹ Integrated Report")
            report_out = gr.Textbox(label="Multi-Agent Analysis Report", lines=40)
            gr.Button("πŸš€ Run Multi-Agent Assessment").click(
                manager_agent,
                inputs=[image_idx, cgm_idx, question, preg, gluc, bp, skin, ins, bmi, ped, age,
                        use_synthetic_img, img_severity, use_synthetic_cgm, cgm_pattern],
                outputs=report_out,
            )

    def _fill_synthetic_patient(profile):
        """Return synthetic tabular features in the order of the patient input widgets."""
        patient = generate_synthetic_patient(profile)
        return [patient[k] for k in ("preg", "gluc", "bp", "skin", "ins", "bmi", "ped", "age")]

    fill_patient_btn.click(
        _fill_synthetic_patient,
        inputs=patient_profile,
        outputs=[preg, gluc, bp, skin, ins, bmi, ped, age],
    )

# Launch only when run as a script, so importing the module has no side effects.
if __name__ == "__main__":
    demo.launch()