# endoscopy / app.py
# chengwang96 — "update UI" (commit 3222a03)
import gradio as gr
from PIL import Image
import os
import random
# ========== Sample image configuration ==========
# Dropdown entry meaning "nothing selected"; prepended to every choice list.
_SAMPLE_PROMPT = "请选择样例图片或者上传"


def _sample_names(count, ext):
    """Return ``["img_1.<ext>", ..., "img_<count>.<ext>"]``."""
    return [f"img_{i}.{ext}" for i in range(1, count + 1)]


# Image enhancement samples: inputs in ENHANCE_INPUT_DIR, results in
# ENHANCE_OUTPUT_DIR under the same file name.
ENHANCE_INPUT_DIR = "sr_input"
ENHANCE_OUTPUT_DIR = "sr_output"
ENHANCE_IMG_NAMES = _sample_names(29, "jpg")
ENHANCE_CHOICES = [_SAMPLE_PROMPT, *ENHANCE_IMG_NAMES]
ENHANCE_DEFAULT = ENHANCE_IMG_NAMES[0]

# Segmentation samples: inputs in SEG_INPUT_DIR, results in SEG_OUTPUT_DIR.
SEG_INPUT_DIR = "seg_input"
SEG_OUTPUT_DIR = "seg_output"
SEG_IMG_NAMES = _sample_names(29, "png")
SEG_CHOICES = [_SAMPLE_PROMPT, *SEG_IMG_NAMES]
SEG_DEFAULT = SEG_IMG_NAMES[0]

# Diagnosis samples: inputs only (there is no output directory).
DIAG_INPUT_DIR = "diag_input"
DIAG_IMG_NAMES = _sample_names(31, "jpg")
DIAG_CHOICES = [_SAMPLE_PROMPT, *DIAG_IMG_NAMES]
DIAG_DEFAULT = DIAG_IMG_NAMES[0]

# Class labels shown in the diagnosis table, formatted "<Chinese> (<Kvasir name>)".
# NOTE(review): several Chinese names do not translate to the paired English
# class (e.g. "食管静脉曲张" is esophageal varices, not esophagitis; "糜烂"/"出血"
# are erosion/bleeding, not dyed-lifted-polyps/dyed-resection-margins).
# Confirm the intended labels; the strings are used verbatim as lookup keys.
KVASIR_CLASSES = [
    "正常粘膜 (normal-z-line)",
    "食管静脉曲张 (esophagitis)",
    "正常盲肠 (normal-cecum)",
    "息肉 (polyps)",
    "溃疡 (ulcerative-colitis)",
    "糜烂 (dyed-lifted-polyps)",
    "出血 (dyed-resection-margins)",
    "正常幽门 (normal-pylorus)",
]

# Text shared by all the "normal" classes.
_ROUTINE_FOLLOW_UP = "建议:无需特殊治疗,常规随访。"

# Suggested next step per class label, in the same order as KVASIR_CLASSES.
SUGGESTION_DICT = {
    "正常粘膜 (normal-z-line)": _ROUTINE_FOLLOW_UP,
    "食管静脉曲张 (esophagitis)": "建议:根据分级考虑内镜下治疗或药物治疗。",
    "正常盲肠 (normal-cecum)": _ROUTINE_FOLLOW_UP,
    "息肉 (polyps)": "建议:考虑内镜下息肉切除,并随访。",
    "溃疡 (ulcerative-colitis)": "建议:药物治疗,必要时内镜下活检。",
    "糜烂 (dyed-lifted-polyps)": "建议:药物保护胃黏膜,可考虑内镜下进一步处理。",
    "出血 (dyed-resection-margins)": "建议:内镜下止血处理,密切观察。",
    "正常幽门 (normal-pylorus)": _ROUTINE_FOLLOW_UP,
}
def open_img(img_path, size=(300, 300)):
    """Load an image for display, resized to ``size``.

    When ``img_path`` does not exist, a flat light-gray placeholder of ``size``
    is returned instead so the UI still has something to show.

    Args:
        img_path: Path of the image file to load.
        size: Target ``(width, height)`` in pixels; defaults to the 300x300
            used throughout the UI.

    Returns:
        A ``PIL.Image.Image`` whose size is exactly ``size``.
    """
    size = tuple(size)
    if not os.path.exists(img_path):
        return Image.new("RGB", size, (200, 200, 200))
    # Image.open is lazy and keeps the file handle open until pixel data is
    # read. Copying inside the context manager loads the data and releases the
    # handle immediately, even when no resize follows.
    with Image.open(img_path) as src:
        img = src.copy()
    if img.size != size:
        img = img.resize(size, Image.LANCZOS)
    return img
def show_enhance_input(img_name):
    """Return the input image of the selected enhancement sample.

    Args:
        img_name: File name chosen in the enhancement dropdown. ``None``, ``""``
            and the prompt entry all mean "nothing selected".

    Returns:
        The image loaded by ``open_img`` from ``ENHANCE_INPUT_DIR``, or ``None``
        when nothing is selected.
    """
    if img_name not in (None, "", "请选择样例图片或者上传"):
        sample_path = os.path.join(ENHANCE_INPUT_DIR, img_name)
        return open_img(sample_path)
    return None
def enhance_demo(img_name):
    """Return the stored enhancement result for the selected sample.

    No model runs here. The result image is read from ``ENHANCE_OUTPUT_DIR``
    under the same file name as the input.

    Args:
        img_name: File name chosen in the enhancement dropdown. ``None``, ``""``
            and the prompt entry mean "nothing selected".

    Returns:
        The image loaded by ``open_img``, or ``None`` when nothing is selected.
    """
    return (
        None
        if img_name in (None, "", "请选择样例图片或者上传")
        else open_img(os.path.join(ENHANCE_OUTPUT_DIR, img_name))
    )
def show_seg_input(img_name):
    """Return the input image of the selected segmentation sample.

    Args:
        img_name: File name chosen in the segmentation dropdown. ``None``, ``""``
            and the prompt entry mean "nothing selected".

    Returns:
        The image loaded by ``open_img`` from ``SEG_INPUT_DIR``, or ``None``.
    """
    no_selection = (None, "", "请选择样例图片或者上传")
    if img_name in no_selection:
        return None
    input_path = os.path.join(SEG_INPUT_DIR, img_name)
    return open_img(input_path)
def segment_demo(img_name):
    """Return the stored segmentation result for the selected sample.

    No model runs here. The result image is read from ``SEG_OUTPUT_DIR``
    under the same file name as the input.

    Args:
        img_name: File name chosen in the segmentation dropdown. ``None``, ``""``
            and the prompt entry mean "nothing selected".

    Returns:
        The image loaded by ``open_img``, or ``None`` when nothing is selected.
    """
    return (
        None
        if img_name in (None, "", "请选择样例图片或者上传")
        else open_img(os.path.join(SEG_OUTPUT_DIR, img_name))
    )
def show_diag_input(img_name):
    """Return the input image of the selected diagnosis sample.

    Args:
        img_name: File name chosen in the diagnosis dropdown. ``None``, ``""``
            and the prompt entry mean "nothing selected".

    Returns:
        The image loaded by ``open_img`` from ``DIAG_INPUT_DIR``, or ``None``.
    """
    if img_name in (None, "", "请选择样例图片或者上传"):
        preview = None
    else:
        preview = open_img(os.path.join(DIAG_INPUT_DIR, img_name))
    return preview
# Category of each bundled diagnosis sample, keyed by the N in "img_<N>.<ext>".
# NOTE(review): the last range covers 29-32, but only img_1..img_31 are listed
# as diagnosis samples.
_DIAG_SAMPLE_CATEGORIES = [
    (range(1, 5), "糜烂 (dyed-lifted-polyps)"),
    (range(5, 9), "出血 (dyed-resection-margins)"),
    (range(9, 13), "食管静脉曲张 (esophagitis)"),
    (range(13, 17), "正常盲肠 (normal-cecum)"),
    (range(17, 21), "正常幽门 (normal-pylorus)"),
    (range(21, 25), "正常粘膜 (normal-z-line)"),
    (range(25, 29), "息肉 (polyps)"),
    (range(29, 33), "溃疡 (ulcerative-colitis)"),
]


def _sample_category(img_name):
    """Return the category assigned to sample ``img_<N>.<ext>``, or ``None``.

    ``None`` is returned when the name does not carry an integer index (for
    example an arbitrary file name) or when the index falls outside every range.
    """
    stem = os.path.splitext(img_name)[0]
    try:
        idx = int(stem.replace("img_", ""))
    except ValueError:
        # Not a bundled sample name; let the caller fall back.
        return None
    for idx_range, category in _DIAG_SAMPLE_CATEGORIES:
        if idx in idx_range:
            return category
    return None


def _mock_probabilities(main_idx, n_cat):
    """Return ``n_cat`` random probabilities that sum to 1, dominated by ``main_idx``.

    The dominant entry is drawn from U(0.85, 0.99). The remaining mass is split
    across the other entries in proportion to U(0.01, 1) draws, so every other
    entry is at most 0.15 and the dominant one is always the largest.
    """
    main_prob = random.uniform(0.85, 0.99)
    rest = 1 - main_prob
    other_raw = [random.uniform(0.01, 1) for _ in range(n_cat - 1)]
    total = sum(other_raw)  # computed once rather than per element
    others = iter([x / total * rest for x in other_raw])
    return [main_prob if i == main_idx else next(others) for i in range(n_cat)]


def diagnose_demo(img_name):
    """Produce a mock diagnosis for the selected sample image.

    No model is run. The category comes from the sample's file index, or is
    picked at random when the name is not a recognised sample. The per-class
    probabilities are randomly generated around that category.

    Args:
        img_name: File name chosen in the diagnosis dropdown. ``None``, ``""``
            and the prompt entry mean "nothing selected".

    Returns:
        ``(table_rows, result_text, suggestion)``:

        - ``table_rows`` is a list of ``[class label, "probability with 4 decimals"]``.
        - ``result_text`` names the top category.
        - ``suggestion`` is the matching text from ``SUGGESTION_DICT``.

        When nothing is selected, returns ``([], "", "")``.
    """
    if img_name in (None, "", "请选择样例图片或者上传"):
        return [], "", ""
    main_cat = _sample_category(img_name)
    if main_cat is None:
        main_cat = random.choice(KVASIR_CLASSES)
    probs = _mock_probabilities(KVASIR_CLASSES.index(main_cat), len(KVASIR_CLASSES))
    result_table = [[cat, f"{p:.4f}"] for cat, p in zip(KVASIR_CLASSES, probs)]
    result_text = f"诊断类别:{main_cat}(概率最大)"
    suggestion = SUGGESTION_DICT.get(main_cat, "建议:请咨询医生。")
    return result_table, result_text, suggestion
# Stylesheet for the Gradio UI. Selectors target the ``elem_classes`` and HTML
# snippets used when building the layout (e.g. "big-group",
# "image-label-container", "img-label-bar", the "-btn" button classes).
# Changes from the earlier version:
# - The primary buttons declared "border: none" together with a color. The
#   color had no visible effect, so it was dropped.
# - Identical margin rules are grouped instead of repeated.
css = """
#main-title {
    text-align: center;
    font-size: 44px;
    font-weight: bold;
    margin-bottom: 18px;
    margin-top: 18px;
    letter-spacing: 2px;
}
.blue-dash-border {
    border: 2.5px dashed #164fa0 !important;
    border-radius: 22px !important;
    padding: 0 !important;
    margin-bottom: 16px !important;
    box-sizing: border-box;
}
.big-group {
    background: #fff !important;
    border: 2px solid #e0e2e5 !important;
    border-radius: 18px !important;
    padding: 28px 18px 24px 18px !important;
    margin: 16px !important;
    box-shadow: 0 3px 12px 2px rgba(60,64,67,.14);
}
.big-title {
    font-size: 28px;
    font-weight: bold;
    margin-bottom: 18px;
    color: #343434;
    letter-spacing: 1px;
    text-align: center;
}
/* ===== 按钮公共属性合并 ===== */
.orange-btn,
.orange-btn-diag {
    background: #A7C0DE !important;
    color: #111 !important;
    font-size: 32px;
    font-weight: bold;
    border: none !important;
    border-radius: 16px !important;
    padding: 12px 20px !important;
    width: auto !important;
}
.gray-btn,
.gray-btn-diag {
    background: #f2f2f2 !important;
    color: #111 !important;
    font-size: 32px;
    font-weight: bold;
    border: 3px dashed #6c91c2 !important;
    border-radius: 16px !important;
    padding: 12px 20px !important;
    width: auto !important;
}
/* ===== 不同之处仅在于左右 margin ===== */
.orange-btn,
.gray-btn {
    margin-left: 48px !important;
    margin-right: 48px !important;
}
/* ====== 诊断模块按钮独立CSS ====== */
.orange-btn-diag,
.gray-btn-diag {
    margin-left: 24px !important;
    margin-right: 24px !important;
}
.button-row {
    margin-top: 8px;
}
.image-label-container,
.image-label-container * {
    margin: 0 !important;
    padding: 0 !important;
    line-height: 0 !important;
    font-size: 0 !important;
    border-spacing: 0 !important;
    box-sizing: border-box !important;
}
.image-label-container {
    width: 300px !important;
    margin-left: auto !important;
    margin-right: auto !important;
    display: flex !important;
    flex-direction: column !important;
    align-items: center !important;
    justify-content: center !important;
    position: relative;
    text-align: center !important;
    padding: 0 !important;
    border: none !important;
}
.image-label-container .gr-image,
.image-label-container .gr-image-preview {
    width: 300px !important;
    height: 300px !important;
    min-width: 300px !important;
    min-height: 300px !important;
    max-width: 300px !important;
    max-height: 300px !important;
    display: flex !important;
    align-items: center !important;
    justify-content: center !important;
    background: #fff !important;
    border-radius: 10px 10px 0 0 !important;
    box-shadow: 0 2px 12px 0 rgba(60,64,67,.08);
}
.image-label-container img {
    display: block !important;
    vertical-align: bottom !important;
    width: 300px !important;
    height: 300px !important;
    margin: 0 auto !important;
    padding: 0 !important;
    border-radius: 10px 10px 0 0 !important;
    background: #fff !important;
    box-shadow: 0 2px 12px 0 rgba(60,64,67,.08);
}
.image-label-container .img-label-bar {
    margin: 0 !important;
    padding-top: 0 !important;
    padding-bottom: 0 !important;
    line-height: normal !important;
    font-size: 22px !important;
    width: 300px !important;
    text-align: center !important;
    background: rgb(17, 26, 110);
    color: #fff;
    border-bottom-left-radius: 10px;
    border-bottom-right-radius: 10px;
    font-weight: normal;
    box-sizing: border-box;
    align-self: center !important;
}
.gradio-container label,
.gradio-container .gr-input-label,
.gradio-container .block-label {
    font-size: 36px !important;
    font-weight: 500 !important;
    color: #222 !important;
    letter-spacing: 0.5px;
}
.gradio-container select,
.gradio-container option,
.gradio-container input[type="text"] {
    font-size: 22px !important;
}
.gradio-container .gr-dataframe,
.gradio-container .gr-dataframe table,
.gradio-container .gr-dataframe table th,
.gradio-container .gr-dataframe table td {
    font-size: 28px !important;
}
.gradio-container .gr-dataframe .block-label,
.gradio-container .gr-dataframe label {
    font-size: 28px !important;
    font-weight: 600 !important;
}
.gradio-container .gr-textbox,
.gradio-container .gr-textbox textarea,
.gradio-container .gr-textbox .block-label,
.gradio-container .gr-textbox label {
    font-size: 28px !important;
}
"""
def reset_enhance():
    """Clear the enhancement panel.

    Returns:
        Values for (dropdown, input image, output image): the prompt entry
        followed by two empty images.
    """
    prompt = "请选择样例图片或者上传"
    empty_image = None
    return (prompt, empty_image, empty_image)
def reset_seg():
    """Clear the segmentation panel.

    Returns:
        Values for (dropdown, input image, output image): the prompt entry
        followed by two empty images.
    """
    prompt = "请选择样例图片或者上传"
    empty_image = None
    return (prompt, empty_image, empty_image)
def reset_diag():
    """Clear the diagnosis panel.

    Returns:
        Values for (dropdown, input image, result table, result text, suggestion):
        the prompt entry, no image, an empty table and two empty strings.
    """
    prompt = "请选择样例图片或者上传"
    cleared_table = []
    cleared_text = ""
    return (prompt, None, cleared_table, cleared_text, cleared_text)
# ---------------------------------------------------------------------------
# UI definition. Inside the nested ``with`` contexts, the order in which
# components are created determines the on-screen layout, so statement order
# here is significant.
# ---------------------------------------------------------------------------
with gr.Blocks(title="消化道疾病智能分析系统") as demo:
    # Inject the module-level stylesheet through a raw <style> tag.
    # NOTE(review): gr.Blocks can also take the stylesheet via its css=
    # argument; confirm which mechanism suits the installed Gradio version.
    gr.HTML(f"<style>{css}</style>")
    gr.HTML("<div id='main-title'>消化道疾病智能分析系统</div>")
    # ============ Row 1: image enhancement & image segmentation ============
    with gr.Row():
        # ---------- Image enhancement ----------
        with gr.Column():
            with gr.Group(elem_classes="blue-dash-border"):
                with gr.Group(elem_classes="big-group"):
                    gr.HTML("<div class='big-title'>图像增强模块</div>")
                    # NOTE(review): the label invites the user to pick a sample
                    # "or upload", but every image component in this layout is
                    # interactive=False, so there is no upload path.
                    enhance_select = gr.Dropdown(
                        choices=ENHANCE_CHOICES,
                        value=ENHANCE_DEFAULT,
                        label="请选择样例图片或者上传",
                        filterable=True
                    )
                    with gr.Row():
                        # Input preview with a caption bar underneath.
                        with gr.Column():
                            with gr.Group(elem_classes="image-label-container"):
                                enhance_input_img = gr.Image(
                                    show_label=False, interactive=False,
                                    width=300, height=300
                                )
                                gr.HTML("<div class='img-label-bar'>原始图片</div>")
                        # Enhancement result with a caption bar underneath.
                        with gr.Column():
                            with gr.Group(elem_classes="image-label-container"):
                                enhance_output_img = gr.Image(
                                    show_label=False, interactive=False,
                                    width=300, height=300
                                )
                                gr.HTML("<div class='img-label-bar'>增强结果图片</div>")
                    with gr.Row(elem_classes="button-row"):
                        enhance_reset_btn = gr.Button("清空", elem_classes="gray-btn")
                        enhance_btn = gr.Button("点击执行图片增强", elem_classes="orange-btn")
            # Selecting a sample previews its input image.
            enhance_select.change(
                fn=show_enhance_input,
                inputs=enhance_select,
                outputs=enhance_input_img
            )
            # The run button shows the stored result image for the selected
            # sample (enhance_demo reads it from ENHANCE_OUTPUT_DIR).
            enhance_btn.click(
                fn=enhance_demo,
                inputs=enhance_select,
                outputs=enhance_output_img
            )
            # "Clear" resets the dropdown to the prompt entry and empties both images.
            enhance_reset_btn.click(
                fn=reset_enhance,
                inputs=None,
                outputs=[enhance_select, enhance_input_img, enhance_output_img]
            )
        # ---------- Image segmentation ----------
        with gr.Column():
            with gr.Group(elem_classes="blue-dash-border"):
                with gr.Group(elem_classes="big-group"):
                    gr.HTML("<div class='big-title'>图像分割模块</div>")
                    seg_select = gr.Dropdown(
                        choices=SEG_CHOICES,
                        value=SEG_DEFAULT,
                        label="请选择样例图片或者上传",
                        filterable=True
                    )
                    with gr.Row():
                        # Input preview with a caption bar underneath.
                        with gr.Column():
                            with gr.Group(elem_classes="image-label-container"):
                                seg_input_img = gr.Image(
                                    show_label=False, interactive=False,
                                    width=300, height=300
                                )
                                gr.HTML("<div class='img-label-bar'>输入图片</div>")
                        # Segmentation result with a caption bar underneath.
                        with gr.Column():
                            with gr.Group(elem_classes="image-label-container"):
                                seg_output_img = gr.Image(
                                    show_label=False, interactive=False,
                                    width=300, height=300
                                )
                                gr.HTML("<div class='img-label-bar'>分割结果图片</div>")
                    with gr.Row(elem_classes="button-row"):
                        seg_reset_btn = gr.Button("清空", elem_classes="gray-btn")
                        seg_btn = gr.Button("点击执行分割", elem_classes="orange-btn")
            # Same wiring as the enhancement panel: preview, show result, clear.
            seg_select.change(fn=show_seg_input, inputs=seg_select, outputs=seg_input_img)
            seg_btn.click(fn=segment_demo, inputs=seg_select, outputs=seg_output_img)
            seg_reset_btn.click(fn=reset_seg, inputs=None, outputs=[seg_select, seg_input_img, seg_output_img])
    # ============ Row 2: disease classification ============
    with gr.Row():
        with gr.Group(elem_classes="blue-dash-border"):
            with gr.Group(elem_classes="big-group"):
                gr.HTML("<div class='big-title'>疾病分类模块</div>")
                with gr.Row():
                    # ---------- Left column: sample selector, input image, buttons ----------
                    with gr.Column():
                        diag_select = gr.Dropdown(
                            choices=DIAG_CHOICES,
                            value=DIAG_DEFAULT,
                            label="请选择样例图片或者上传",
                            filterable=True
                        )
                        with gr.Column():
                            with gr.Group(elem_classes="image-label-container"):
                                diag_input_img = gr.Image(
                                    show_label=False, interactive=False,
                                    width=300, height=300
                                )
                                gr.HTML("<div class='img-label-bar'>输入图片</div>")
                        with gr.Row(elem_classes="button-row"):
                            # Diagnosis buttons use their own "-diag" CSS classes,
                            # which carry narrower side margins (24px vs 48px).
                            diag_reset_btn = gr.Button("清空", elem_classes="gray-btn-diag")
                            diag_btn = gr.Button("点击执行诊断", elem_classes="orange-btn-diag")
                        diag_select.change(fn=show_diag_input, inputs=diag_select, outputs=diag_input_img)
                    # ---------- Middle column: per-class probability table ----------
                    with gr.Column():
                        diag_table = gr.Dataframe(
                            headers=["诊断类别", "模型预测概率"]
                        )
                    # ---------- Right column: diagnosis result & suggested treatment ----------
                    with gr.Column():
                        diag_result = gr.Textbox(label="诊断结果")
                        diag_suggestion = gr.Textbox(label="建议的治疗方案")
                # diagnose_demo returns (table rows, result text, suggestion text).
                diag_btn.click(
                    fn=diagnose_demo,
                    inputs=diag_select,
                    outputs=[diag_table, diag_result, diag_suggestion]
                )
                # "Clear" resets the selector and empties the image, table and text boxes.
                diag_reset_btn.click(
                    fn=reset_diag,
                    inputs=None,
                    outputs=[diag_select, diag_input_img, diag_table, diag_result, diag_suggestion]
                )
    # ====== On page load, show the default samples and their results ======
    # Each handler receives its default sample name through a constant gr.State.
    # The diagnosis probabilities are random, so they differ on every page load.
    demo.load(fn=show_enhance_input, inputs=gr.State(ENHANCE_DEFAULT), outputs=enhance_input_img)
    demo.load(fn=enhance_demo, inputs=gr.State(ENHANCE_DEFAULT), outputs=enhance_output_img)
    demo.load(fn=show_seg_input, inputs=gr.State(SEG_DEFAULT), outputs=seg_input_img)
    demo.load(fn=segment_demo, inputs=gr.State(SEG_DEFAULT), outputs=seg_output_img)
    demo.load(fn=show_diag_input, inputs=gr.State(DIAG_DEFAULT), outputs=diag_input_img)
    demo.load(fn=diagnose_demo, inputs=gr.State(DIAG_DEFAULT), outputs=[diag_table, diag_result, diag_suggestion])
# NOTE(review): the server is launched unconditionally at module level; guard
# with ``if __name__ == "__main__":`` if this module is ever imported.
demo.launch()