# saim1309's picture
# Update app.py
# b6175f7 verified
import gradio as gr
import cv2
import numpy as np
import io
from PIL import Image, ImageDraw, ImageFont
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# ─── Font paths ───────────────────────────────────────────────────────────────
# Absolute paths of the DejaVu Sans TrueType files used for all report text.
# `_font` picks one of the two and falls back to PIL's built-in font when the
# chosen file cannot be loaded.
FONT_BOLD = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
FONT_REGULAR = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
def _font(size, bold=True):
    """Return a DejaVu Sans font at ``size``, degrading gracefully if missing.

    Args:
        size: Font size handed to ``ImageFont.truetype``.
        bold: Load ``FONT_BOLD`` when true, ``FONT_REGULAR`` otherwise.

    Returns:
        A PIL font object. When the TrueType file cannot be loaded, PIL's
        built-in default font is returned instead, scaled to ``size`` where
        the installed Pillow supports it.
    """
    path = FONT_BOLD if bold else FONT_REGULAR
    try:
        return ImageFont.truetype(path, size)
    except Exception:
        # Deliberate best-effort: a missing font file must not break rendering.
        try:
            # Newer Pillow releases accept a size for the built-in font, which
            # keeps the report layout proportions intact.
            return ImageFont.load_default(size=size)
        except Exception:
            # Older Pillow (no ``size`` argument) or no FreeType support:
            # fall back to the fixed-size bitmap font.
            return ImageFont.load_default()
# ─── Colours (R,G,B) ──────────────────────────────────────────────────────────
# Palette as (R, G, B) tuples. In the report table the saturated colours are
# used for accent stripes and section bars, and the matching *_LITE colours
# for the row backgrounds.
C_VIOLET = (124, 58, 237)
C_VIOLET_DARK = ( 91, 33, 182)
C_VIOLET_LITE = (237, 233, 254)
C_TEAL = ( 13, 148, 136)
C_TEAL_LITE = (204, 251, 241)
C_AMBER = (217, 119, 6)
C_AMBER_LITE = (254, 243, 199)
C_ROSE = (225, 29, 72)
C_ROSE_LITE = (255, 228, 230)
C_SLATE = ( 30, 41, 59)
C_SLATE_MID = (100, 116, 139)
C_SLATE_LITE = (241, 245, 249)
C_WHITE = (255, 255, 255)
C_BLACK = ( 0, 0, 0)
C_RED = (220, 38, 38)
C_BG = (255, 255, 255) # page background
# ─── Cellpose model (lazy) ────────────────────────────────────────────────────
# Module-level cache for the Cellpose model; filled on first use.
_model = None


def get_model():
    """Return the shared Cellpose model, creating it on the first call.

    The first call downloads the ``cpsam`` weights file from the
    ``mouseland/cellpose-sam`` Hugging Face repo and builds a CPU
    ``CellposeModel`` from it; later calls return the cached instance.
    """
    global _model
    if _model is not None:
        return _model

    # Imported here so the heavy dependencies load only when a model is needed.
    from cellpose import models
    from huggingface_hub import hf_hub_download

    weights_path = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
    _model = models.CellposeModel(gpu=False, pretrained_model=weights_path)
    return _model
# ─── Image helpers ────────────────────────────────────────────────────────────
def normalize99(img):
    """Rescale ``img`` so its 1st percentile maps to 0 and its 99th to 1.

    Args:
        img: Array-like of numeric pixel values; it is not modified.

    Returns:
        A new floating-point array of the same shape. Values outside the
        1st-99th percentile range fall outside [0, 1]; callers clip if needed.
        A constant image yields all zeros.
    """
    # ``astype`` already returns a fresh array, so no separate copy is needed.
    X = np.asarray(img).astype(np.float32)
    # One pass over the data for both percentiles.
    p1, p99 = np.percentile(X, [1, 99])
    # Add the epsilon to the (possibly zero) span, not to p99: in float32
    # arithmetic ``1e-10 + p99`` can round back to p99, which would turn a
    # constant image into 0/0 = NaN.
    span = (p99 - p1) + 1e-10
    return (X - p1) / span
def image_resize(img, resize=1000):
    """Shrink ``img`` so its longer side is at most ``resize`` pixels.

    Args:
        img: Image array whose first two axes are (rows, columns).
        resize: Maximum allowed length of the longer side, in pixels.

    Returns:
        The image as ``uint8``. Images already within the limit are only
        cast, not resampled; larger ones are scaled with ``cv2.resize``
        keeping the aspect ratio.
    """
    ny, nx = img.shape[:2]
    if max(ny, nx) > resize:
        # Clamp the short side to 1 px: for extreme aspect ratios the
        # truncating int() would otherwise yield 0 and cv2.resize would fail.
        if ny > nx:
            nx, ny = max(1, int(nx / ny * resize)), resize
        else:
            ny, nx = max(1, int(ny / nx * resize)), resize
        img = cv2.resize(img, (nx, ny))  # cv2 takes (width, height)
    return img.astype(np.uint8)
def run_cellpose(img, model, max_iter=250, flow_threshold=0.4, cellprob_threshold=0.0):
    """Run ``model.eval`` on ``img`` and return ``(masks, flows)``.

    ``max_iter`` is forwarded as the ``niter`` keyword; the two thresholds are
    forwarded under their own names. The third item returned by ``eval`` is
    discarded.
    """
    eval_kwargs = {
        "niter": max_iter,
        "flow_threshold": flow_threshold,
        "cellprob_threshold": cellprob_threshold,
    }
    masks, flows, _unused = model.eval(img, **eval_kwargs)
    return masks, flows
def build_outline_image(img, masks) -> Image.Image:
    """Render ``img`` with the outline of every mask drawn in red.

    Args:
        img: Image array; it is contrast-stretched with ``normalize99`` and
            clipped to [0, 1] before display.
        masks: Integer label array; contours are extracted from it with
            ``cv2.findContours``.

    Returns:
        A PIL image decoded from the PNG that matplotlib renders. Its pixel
        size is decided by the figure size and DPI, not by ``img``.
    """
    img_n = np.clip(normalize99(img), 0, 1)

    contours, _ = cv2.findContours(
        masks.astype(np.int32), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_SIMPLE
    )
    # Keep only contours with more than four points; tiny ones are skipped.
    outpix = [
        cv2.approxPolyDP(c, 0.001, True)[:, 0, :]
        for c in contours
        if len(c) > 4
    ]

    h, w = img_n.shape[:2]
    # Longer side is 6 inches; the other side follows the aspect ratio.
    figsize = (6, 6 * h / w) if w >= h else (6 * w / h, 6)

    buf = io.BytesIO()
    fig = plt.figure(figsize=figsize, facecolor="k")
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.set_xlim([0, w])
        ax.set_ylim([0, h])
        # The y axis runs upwards (ylim 0..h), so both the image rows and the
        # outline y coordinates are mirrored to keep the picture upright.
        ax.imshow(img_n[::-1], origin="upper", aspect="auto")
        for o in outpix:
            ax.plot(o[:, 0], h - o[:, 1], color=[1, 0, 0], lw=1)
        ax.axis("off")
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    buf.seek(0)
    # copy() detaches the image from the in-memory buffer.
    return Image.open(buf).copy()
# ─── Drawing helpers ──────────────────────────────────────────────────────────
def _text_size(draw, text, font):
"""Return (width, height) of text."""
bbox = draw.textbbox((0, 0), text, font=font)
return bbox[2] - bbox[0], bbox[3] - bbox[1]
def _draw_rect(img, x0, y0, x1, y1, fill, border=None, border_width=2, radius=0):
    """Paint a filled box on ``img``, optionally outlined and/or rounded.

    The outline is drawn with ``border_width`` only when ``border`` is truthy;
    otherwise the outline width is 0. A positive ``radius`` selects a rounded
    rectangle, anything else a plain one.
    """
    canvas = ImageDraw.Draw(img)
    outline_width = border_width if border else 0
    box = [x0, y0, x1, y1]

    if radius > 0:
        canvas.rounded_rectangle(
            box, radius=radius, fill=fill, outline=border, width=outline_width
        )
        return
    canvas.rectangle(box, fill=fill, outline=border, width=outline_width)
def _draw_text_centred(img, cx, cy, text, font, color):
    """Draw ``text`` on ``img`` so its measured box is centred on (cx, cy)."""
    pen = ImageDraw.Draw(img)
    width, height = _text_size(pen, text, font)
    origin = (cx - width // 2, cy - height // 2)
    pen.text(origin, text, font=font, fill=color)
def _draw_text_left(img, x, cy, text, font, color):
    """Draw ``text`` starting at ``x``, vertically centred on ``cy``."""
    pen = ImageDraw.Draw(img)
    height = _text_size(pen, text, font)[1]
    origin = (x, cy - height // 2)
    pen.text(origin, text, font=font, fill=color)
# ─── Report image builder ─────────────────────────────────────────────────────
def build_report_image(segmented_pil: Image.Image, total_count: int) -> Image.Image:
    """Render the grain report as a single PIL image.

    Layout, top to bottom:
      - Header: "MLBench" logo, right-aligned tagline, full-width teal rule.
      - Body:   [Grain Count Statistics table] | gap | [Segmentation Output image]
    No footer line or page number is drawn.

    Args:
        segmented_pil: Picture shown in the right-hand column. It is scaled,
            keeping its aspect ratio, to fit a box as tall as the table and
            centred on a black background.
        total_count: Number shown in the "Total Rice Grain" row. The other
            rows show a dash placeholder.

    Returns:
        A new RGB image, A4 width at 150 DPI, with its height cropped to the
        content.
    """
    DPI = 150
    PW_IN = 8.27  # A4 width in inches
    PW = int(PW_IN * DPI)
    MARGIN = int(0.7 * DPI)  # ~0.7 inch margin

    # ── Fonts ─────────────────────────────────────────────────────────────
    f_logo = _font(int(0.28 * DPI))  # "ML" and "Bench" share one size
    f_tagline = _font(int(0.09 * DPI), bold=False)
    f_sec_hdr = _font(int(0.11 * DPI))  # section bar text
    f_col_hdr = _font(int(0.09 * DPI))  # table column headers
    f_label = _font(int(0.10 * DPI))  # row labels
    f_val_total = _font(int(0.13 * DPI))  # total count value (bigger)
    f_val = _font(int(0.10 * DPI))  # other value cells

    # ── Table content ─────────────────────────────────────────────────────
    # Written as an escape so the em dash survives any source re-encoding.
    placeholder = "\u2014"
    stat_rows_def = [
        ("Total Rice Grain", str(total_count), C_VIOLET, C_VIOLET_LITE),
        ("Long Grain", placeholder, C_TEAL, C_TEAL_LITE),
        ("Short Grain", placeholder, C_AMBER, C_AMBER_LITE),
        ("Half Grain", placeholder, C_ROSE, C_ROSE_LITE),
        ("Broken Edge", placeholder, C_SLATE_MID, C_SLATE_LITE),
    ]

    # ── Dimensions ────────────────────────────────────────────────────────
    usable_w = PW - 2 * MARGIN
    GAP = int(0.18 * DPI)
    stat_w = int(usable_w * 0.43)
    img_col_w = usable_w - stat_w - GAP
    HDR_H = int(0.55 * DPI)  # header area height
    SEC_BAR_H = int(0.22 * DPI)  # coloured section title bar
    COL_HDR_H = int(0.18 * DPI)  # table column header row
    ROW_H = int(0.17 * DPI)  # each data row
    STRIPE_W = int(0.07 * DPI)  # coloured left stripe on each row
    TEAL_LINE = 3  # teal rule thickness
    N_ROWS = len(stat_rows_def)
    TABLE_H = COL_HDR_H + N_ROWS * ROW_H
    BODY_TOP = HDR_H + int(0.12 * DPI)  # y where body starts
    # Canvas height: header + gap + section bar + table + bottom margin.
    CANVAS_H = BODY_TOP + SEC_BAR_H + TABLE_H + MARGIN

    # ── Create canvas ─────────────────────────────────────────────────────
    img = Image.new("RGB", (PW, CANVAS_H), C_BG)
    draw = ImageDraw.Draw(img)

    # ── Header ────────────────────────────────────────────────────────────
    # "ML" in red immediately followed by "Bench" in black.
    logo_y = int(HDR_H * 0.38)
    ml_w, _ = _text_size(draw, "ML", f_logo)
    draw.text((MARGIN, logo_y), "ML", font=f_logo, fill=C_RED)
    draw.text((MARGIN + ml_w, logo_y), "Bench", font=f_logo, fill=C_BLACK)

    # Tagline, right-aligned against the margin.
    tag = "Rice Grain Analysis Report"
    tag_w, _ = _text_size(draw, tag, f_tagline)
    draw.text((PW - MARGIN - tag_w, logo_y + 6), tag, font=f_tagline, fill=C_SLATE_MID)

    # Teal horizontal rule across the full page width.
    rule_y = HDR_H - 4
    draw.rectangle([0, rule_y, PW, rule_y + TEAL_LINE], fill=C_TEAL)

    # ── Section header bars ───────────────────────────────────────────────
    stat_x = MARGIN
    img_x = MARGIN + stat_w + GAP
    bar_y0 = BODY_TOP
    bar_y1 = BODY_TOP + SEC_BAR_H
    bar_cy = (bar_y0 + bar_y1) // 2

    # Teal bar - "Grain Count Statistics"
    _draw_rect(img, stat_x, bar_y0, stat_x + stat_w, bar_y1, fill=C_TEAL)
    _draw_text_centred(img, stat_x + stat_w // 2, bar_cy,
                       "Grain Count Statistics", f_sec_hdr, C_WHITE)

    # Violet bar - "Segmentation Output"
    _draw_rect(img, img_x, bar_y0, img_x + img_col_w, bar_y1, fill=C_VIOLET)
    _draw_text_centred(img, img_x + img_col_w // 2, bar_cy,
                       "Segmentation Output", f_sec_hdr, C_WHITE)

    # ── Stats table ───────────────────────────────────────────────────────
    table_top = BODY_TOP + SEC_BAR_H
    table_bottom = table_top + TABLE_H
    col_hdr_y0 = table_top
    col_hdr_y1 = table_top + COL_HDR_H
    col_hdr_cy = (col_hdr_y0 + col_hdr_y1) // 2

    # Column header background and titles.
    _draw_rect(img, stat_x, col_hdr_y0, stat_x + stat_w, col_hdr_y1, fill=C_SLATE)
    inner_w = stat_w - STRIPE_W
    cat_cx = stat_x + STRIPE_W + inner_w // 2 - int(inner_w * 0.18)
    count_cx = stat_x + STRIPE_W + int(inner_w * 0.78)
    _draw_text_centred(img, cat_cx, col_hdr_cy, "Category", f_col_hdr, C_WHITE)
    _draw_text_centred(img, count_cx, col_hdr_cy, "Count", f_col_hdr, C_WHITE)

    border_color = (203, 213, 225)
    grid_color = (226, 232, 240)

    for i, (label, val, accent, bg) in enumerate(stat_rows_def):
        ry0 = col_hdr_y1 + i * ROW_H
        ry1 = ry0 + ROW_H
        cy = (ry0 + ry1) // 2

        # Row background, then the accent stripe on its left edge.
        _draw_rect(img, stat_x, ry0, stat_x + stat_w, ry1, fill=bg)
        _draw_rect(img, stat_x, ry0, stat_x + STRIPE_W, ry1, fill=accent)

        # Label
        _draw_text_left(img, stat_x + STRIPE_W + 8, cy, label, f_label, C_SLATE)

        # Value, right-aligned; the first (total) row is larger and violet.
        is_total = i == 0
        f_v = f_val_total if is_total else f_val
        c_v = C_VIOLET if is_total else C_SLATE
        vw, vh = _text_size(draw, val, f_v)
        draw.text((stat_x + stat_w - vw - 14, cy - vh // 2), val, font=f_v, fill=c_v)

        # Horizontal grid line along the bottom of the row.
        draw.rectangle([stat_x, ry1 - 1, stat_x + stat_w, ry1], fill=grid_color)

    # Outer border of table (column header + rows).
    draw.rectangle([stat_x, col_hdr_y0, stat_x + stat_w, table_bottom],
                   outline=border_color, width=1)

    # ── Segmentation image ────────────────────────────────────────────────
    # Fit the picture inside a box exactly as tall as the table.
    iw, ih = segmented_pil.size
    scale = min(img_col_w / iw, TABLE_H / ih)
    # Clamp to 1 px so an extreme aspect ratio cannot produce a zero size.
    new_w = max(1, int(iw * scale))
    new_h = max(1, int(ih * scale))
    seg_resized = segmented_pil.resize((new_w, new_h), Image.BICUBIC)

    box_x0 = img_x
    box_y0 = table_top  # align top with table (below section bar)
    box_x1 = img_x + img_col_w
    box_y1 = table_bottom

    # Black background first, then the picture centred inside it.
    _draw_rect(img, box_x0, box_y0, box_x1, box_y1, fill=C_BLACK)
    paste_x = box_x0 + (img_col_w - new_w) // 2
    paste_y = box_y0 + (TABLE_H - new_h) // 2
    img.paste(seg_resized, (paste_x, paste_y))

    # Border last, so a picture that fills the box cannot paint over it.
    _draw_rect(img, box_x0, box_y0, box_x1, box_y1,
               fill=None, border=C_VIOLET, border_width=2)

    return img
# ─── Sample example images ────────────────────────────────────────────────────
# Image files offered as clickable examples in the UI (passed to gr.Examples).
# The paths are relative, so they resolve against the process working directory.
SAMPLE_PATHS = [
"kainat.jpg",
"c9.jpg"
]
# ─── Status helpers ───────────────────────────────────────────────────────────
def make_status(level: str, message: str) -> dict:
    """Build the Gradio update that shows ``message`` in the status box.

    Args:
        level: One of "success", "warning", "error" or "info"; any other
            value falls back to the "info" icon.
        message: Text displayed after the icon.

    Returns:
        The result of ``gr.update`` setting the value to "<icon> <message>"
        and making the component visible.
    """
    # Escapes keep the emoji intact regardless of how the source is re-encoded.
    icons = {
        "success": "\u2705",        # white heavy check mark
        "warning": "\u26a0\ufe0f",  # warning sign (emoji presentation)
        "error": "\u274c",          # cross mark
        "info": "\u2139\ufe0f",     # information source (emoji presentation)
    }
    icon = icons.get(level, icons["info"])
    return gr.update(value=f"{icon} {message}", visible=True)
# ─── Main processing ──────────────────────────────────────────────────────────
def process_image(pil_image):
    """Segment the uploaded picture and build the report shown in the UI.

    Returns a ``(report_image, status_update)`` pair. ``report_image`` is
    ``None`` when no image was supplied, no grains were found, or an error
    occurred; ``status_update`` always carries a message for the status box.
    """
    if pil_image is None:
        return None, make_status("warning", "No image provided. Please upload or select a sample image first.")

    try:
        rgb = np.array(pil_image.convert("RGB"))
        scaled = image_resize(rgb, resize=1000)

        masks, _flows = run_cellpose(scaled, get_model())

        # The largest label value is taken as the number of grains.
        grain_total = int(masks.max())
        if grain_total == 0:
            return None, make_status(
                "warning",
                "No rice grains were detected in this image. "
                "Try a clearer photo or adjust the image contrast."
            )

        # Bring the rendered overlay back to the size of the segmented image.
        height, width = scaled.shape[0], scaled.shape[1]
        overlay = build_outline_image(scaled, masks)
        overlay = overlay.resize((width, height), resample=Image.BICUBIC)

        report = build_report_image(overlay, grain_total)
        status = make_status(
            "success",
            f"{grain_total} rice grains detected. Report image shown on the right.",
        )
        return report, status
    except MemoryError:
        return None, make_status("error", "Out of memory. Try uploading a smaller image.")
    except Exception as exc:
        # UI boundary: print the traceback and report the failure to the user.
        import traceback
        traceback.print_exc()
        return None, make_status("error", f"Unexpected error: {type(exc).__name__}: {exc}")
# ─── UI ───────────────────────────────────────────────────────────────────────
# Gradio "Soft" theme configured with violet / indigo / slate hues and the
# Inter Google font.
THEME = gr.themes.Soft(
primary_hue="violet",
secondary_hue="indigo",
neutral_hue="slate",
font=gr.themes.GoogleFont("Inter"),
)
# Extra stylesheet handed to demo.launch(); the selectors match the elem_id
# values "run-btn" and "status-box" assigned to components in the UI below.
CSS = """
#run-btn { margin-top: 6px; }
#status-box textarea { font-size: 0.92rem; }
"""
# Page layout: banner on top, then two columns (input and controls on the
# left, generated report on the right).
with gr.Blocks(title="Rice Grain Counter") as demo:
    # Static title banner.
    gr.HTML("""
<div style="padding:18px 12px 10px 12px; background-color:#0F172A; border-radius:10px; margin-bottom:10px;">
<span style="font-size:2rem;font-weight:900;color:#F1F5F9;font-family:sans-serif;">
Rice Grain Counter
</span>
<p style="color:#CBD5E1;font-size:0.9rem;margin-top:4px;font-family:sans-serif;">
Upload a rice image to segment each grain and generate a report.
</p>
</div>
""")
    with gr.Row(equal_height=False):
        # ── LEFT COLUMN ───────────────────────────────────────────────────
        with gr.Column(scale=1):
            inp_image = gr.Image(type="pil", label="Upload Rice Image", height=270)
            # The magnifying-glass emoji is written as an escape so it
            # survives source re-encoding.
            run_btn = gr.Button("\U0001F50D Analyse & Generate Report",
                                variant="primary", size="lg", elem_id="run-btn")
            gr.Markdown("_Upload an image or click a sample below, then press **Analyse**._")
            # Hidden until process_image returns its first status update.
            status_box = gr.Textbox(
                label="Status",
                value="",
                interactive=False,
                visible=False,
                max_lines=3,
                elem_id="status-box",
            )
            gr.Markdown("### Example Images _(click to load)_")
            gr.Examples(
                examples=[[p] for p in SAMPLE_PATHS],
                inputs=inp_image,
                label="",
                examples_per_page=6,
            )
        # ── RIGHT COLUMN ──────────────────────────────────────────────────
        with gr.Column(scale=1):
            gr.Markdown("### Report")
            report_out = gr.Image(
                label="",
                interactive=False,
            )
    # process_image returns (report image, status update), in that order.
    run_btn.click(
        fn=process_image,
        inputs=[inp_image],
        outputs=[report_out, status_box],
    )

if __name__ == "__main__":
    # THEME was defined but never applied; it is passed alongside css here.
    # NOTE(review): launch() accepting css/theme depends on the installed
    # Gradio release - confirm against the pinned version.
    demo.launch(share=True, css=CSS, theme=THEME)