import io

import cv2
import gradio as gr
import matplotlib
import numpy as np
from PIL import Image, ImageDraw, ImageFont

matplotlib.use("Agg")  # non-interactive backend: figures are only rendered into buffers
import matplotlib.pyplot as plt  # noqa: E402  (must come after matplotlib.use)

# ─── Font paths ───────────────────────────────────────────────────────────────
FONT_BOLD = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
FONT_REGULAR = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"


def _font(size, bold=True):
    """Return DejaVu Sans (bold by default) at *size*, or a Pillow fallback font.

    The fallback is used when the DejaVu file cannot be opened (OSError) or
    Pillow was built without FreeType (ImportError). Pillow >= 10.1 can build
    its built-in font at *size*; older versions only offer a fixed-size
    bitmap font.
    """
    try:
        return ImageFont.truetype(FONT_BOLD if bold else FONT_REGULAR, size)
    except (OSError, ImportError):
        try:
            return ImageFont.load_default(size)
        except (TypeError, ImportError):
            # TypeError: load_default() takes no size before Pillow 10.1.
            # ImportError: a sized default font cannot be built without FreeType.
            return ImageFont.load_default()


# ─── Colours (R,G,B) ──────────────────────────────────────────────────────────
C_VIOLET      = (124,  58, 237)
C_VIOLET_DARK = ( 91,  33, 182)
C_VIOLET_LITE = (237, 233, 254)
C_TEAL        = ( 13, 148, 136)
C_TEAL_LITE   = (204, 251, 241)
C_AMBER       = (217, 119,   6)
C_AMBER_LITE  = (254, 243, 199)
C_ROSE        = (225,  29,  72)
C_ROSE_LITE   = (255, 228, 230)
C_SLATE       = ( 30,  41,  59)
C_SLATE_MID   = (100, 116, 139)
C_SLATE_LITE  = (241, 245, 249)
C_WHITE       = (255, 255, 255)
C_BLACK       = (  0,   0,   0)
C_RED         = (220,  38,  38)
C_BG          = (255, 255, 255)  # page background

# ─── Cellpose model (lazy) ────────────────────────────────────────────────────
_model = None


def get_model():
    """Return the shared Cellpose-SAM model, downloading and loading it on first call.

    The model is cached in the module-level ``_model`` and runs on CPU
    (``gpu=False``).
    """
    global _model
    if _model is None:
        # Deferred imports: these packages are only needed once a segmentation is requested.
        from cellpose import models
        from huggingface_hub import hf_hub_download

        fpath = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
        _model = models.CellposeModel(gpu=False, pretrained_model=fpath)
    return _model


# ─── Image helpers ────────────────────────────────────────────────────────────
def normalize99(img):
    """Rescale *img* so its 1st percentile maps to 0 and its 99th to 1 (float32).

    Values outside that range are NOT clipped; callers clip if they need [0, 1].
    """
    X = img.astype(np.float32)  # astype returns a new array, so *img* is untouched
    p1, p99 = np.percentile(X, [1, 99])
    return (X - p1) / (1e-10 + p99 - p1)


def image_resize(img, resize=1000):
    """Downscale *img* so its longer side is at most *resize* px, keeping aspect ratio.

    Smaller images keep their size. The result is always cast to uint8
    (non-uint8 input is cast, not rescaled).
    """
    ny, nx = img.shape[:2]
    if max(ny, nx) > resize:
        # max(1, ...) keeps very elongated images from collapsing to 0 px,
        # which cv2.resize rejects.
        if ny > nx:
            nx, ny = max(1, int(nx / ny * resize)), resize
        else:
            ny, nx = max(1, int(ny / nx * resize)), resize
        img = cv2.resize(img, (nx, ny))  # cv2 expects (width, height)
    return img.astype(np.uint8)


def run_cellpose(img, model, max_iter=250, flow_threshold=0.4, cellprob_threshold=0.0):
    """Segment *img* with *model* and return ``(masks, flows)``.

    *max_iter* is forwarded as ``niter``; the third value returned by
    ``model.eval`` is discarded.
    """
    masks, flows, _ = model.eval(
        img,
        niter=max_iter,
        flow_threshold=flow_threshold,
        cellprob_threshold=cellprob_threshold,
    )
    return masks, flows


def build_outline_image(img, masks) -> Image.Image:
    """Render *img* with each labelled region of *masks* outlined in red.

    The image is contrast-stretched with ``normalize99`` and clipped to [0, 1].
    Contours with 4 or fewer points are skipped. The figure's longest side is
    6 inches, so the output pixel size depends on the figure DPI and generally
    differs from *img*.
    """
    img_n = np.clip(normalize99(img), 0, 1)
    contours, _ = cv2.findContours(
        masks.astype(np.int32), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_SIMPLE
    )
    outpix = [
        cv2.approxPolyDP(c, 0.001, True)[:, 0, :]
        for c in contours
        if len(c) > 4  # contours are (N, 1, 2); keep those with more than 4 points
    ]

    h, w = img_n.shape[:2]
    figsize = (6, 6 * h / w) if w >= h else (6 * w / h, 6)
    fig = plt.figure(figsize=figsize, facecolor="k")
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.set_xlim([0, w])
        ax.set_ylim([0, h])
        # The y-limits run 0..h (y up), so both the image rows and the contour
        # y-coordinates are flipped to keep image row 0 at the top.
        ax.imshow(img_n[::-1], origin="upper", aspect="auto")
        for o in outpix:
            ax.plot(o[:, 0], h - o[:, 1], color=[1, 0, 0], lw=1)
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buf.seek(0)
        # copy() loads the pixel data so the returned image does not depend on buf.
        return Image.open(buf).copy()
    finally:
        # Always release the figure, even if rendering fails, so figures do not accumulate.
        plt.close(fig)


# ─── Drawing helpers ──────────────────────────────────────────────────────────
def _text_size(draw, text, font):
    """Return the (width, height) of the ink bounding box of *text* in *font*."""
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    return right - left, bottom - top


def _draw_rect(img, x0, y0, x1, y1, fill, border=None, border_width=2, radius=0):
    """Draw a filled (optionally rounded) rectangle with an optional border on *img*."""
    draw = ImageDraw.Draw(img)
    width = border_width if border else 0
    box = [x0, y0, x1, y1]
    if radius > 0:
        draw.rounded_rectangle(box, radius=radius, fill=fill, outline=border, width=width)
    else:
        draw.rectangle(box, fill=fill, outline=border, width=width)


def _draw_text_centred(img, cx, cy, text, font, color):
    """Draw *text* so the centre of its ink bounding box lies at (cx, cy)."""
    draw = ImageDraw.Draw(img)
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    # textbbox is measured from the draw origin and usually has a non-zero top
    # (and sometimes left) offset. Subtract the box centre, not half its size,
    # so the glyphs are actually centred.
    draw.text((cx - (left + right) // 2, cy - (top + bottom) // 2),
              text, font=font, fill=color)


def _draw_text_left(img, x, cy, text, font, color):
    """Draw *text* with its ink starting at *x* and vertically centred on *cy*."""
    draw = ImageDraw.Draw(img)
    left, top, _, bottom = draw.textbbox((0, 0), text, font=font)
    draw.text((x - left, cy - (top + bottom) // 2), text, font=font, fill=color)
# ─── Report image builder ─────────────────────────────────────────────────────
def build_report_image(segmented_pil: Image.Image, total_count: int) -> Image.Image:
    """Render the grain-count report as a single RGB PIL image.

    Layout (A4 width at 150 DPI; height is cropped to the content):
      • Header : "MLBench" logo, right-aligned tagline, teal rule
      • Body   : [Grain Count Statistics table] | gap | [Segmentation Output image]
    There is no footer line or page number.

    Args:
        segmented_pil: Image shown in the right-hand column. It is scaled
            (aspect preserved) to fit inside a violet-framed black box as
            tall as the statistics table.
        total_count: Value shown in the "Total Rice Grain" row. The other
            categories are rendered with a "—" placeholder.

    Returns:
        A newly created RGB image.
    """
    DPI = 150
    PW = int(8.27 * DPI)     # A4 width (8.27 in) in px; height is sized to content
    MARGIN = int(0.7 * DPI)  # ~0.7 inch margin

    # ── Fonts ─────────────────────────────────────────────────────────────
    f_logo = _font(int(0.28 * DPI))                # "ML" + "Bench" logo
    f_tagline = _font(int(0.09 * DPI), bold=False)
    f_sec_hdr = _font(int(0.11 * DPI))             # section bar text
    f_col_hdr = _font(int(0.09 * DPI))             # table column headers
    f_label = _font(int(0.10 * DPI))               # row labels
    f_val_total = _font(int(0.13 * DPI))           # total count value (bigger)
    f_val = _font(int(0.10 * DPI))                 # other value cells

    # ── Table content ─────────────────────────────────────────────────────
    stat_rows_def = [
        ("Total Rice Grain", str(total_count), C_VIOLET,    C_VIOLET_LITE),
        ("Long Grain",       "—",              C_TEAL,      C_TEAL_LITE),
        ("Short Grain",      "—",              C_AMBER,     C_AMBER_LITE),
        ("Half Grain",       "—",              C_ROSE,      C_ROSE_LITE),
        ("Broken Edge",      "—",              C_SLATE_MID, C_SLATE_LITE),
    ]

    # ── Dimensions ────────────────────────────────────────────────────────
    usable_w = PW - 2 * MARGIN
    GAP = int(0.18 * DPI)
    stat_w = int(usable_w * 0.43)
    img_col_w = usable_w - stat_w - GAP

    HDR_H = int(0.55 * DPI)      # header area height
    SEC_BAR_H = int(0.22 * DPI)  # coloured section title bar
    COL_HDR_H = int(0.18 * DPI)  # table column header row
    ROW_H = int(0.17 * DPI)      # each data row
    STRIPE_W = int(0.07 * DPI)   # coloured left stripe on each row
    TEAL_LINE = 3                # teal rule thickness
    BOX_BORDER = 2               # violet frame width around the segmentation image
    # Derived from the row list so the table height always matches the rows drawn.
    TABLE_H = COL_HDR_H + len(stat_rows_def) * ROW_H

    BODY_TOP = HDR_H + int(0.12 * DPI)  # y where body starts
    CANVAS_H = BODY_TOP + SEC_BAR_H + TABLE_H + MARGIN

    # ── Create canvas ─────────────────────────────────────────────────────
    img = Image.new("RGB", (PW, CANVAS_H), C_BG)
    draw = ImageDraw.Draw(img)

    # ── Header ────────────────────────────────────────────────────────────
    # "ML" in red, "Bench" in black
    logo_y = int(HDR_H * 0.38)
    draw.text((MARGIN, logo_y), "ML", font=f_logo, fill=C_RED)
    # Use the advance width (not the ink width) so "Bench" starts exactly where
    # the next glyph after "ML" would.
    bench_x = MARGIN + int(round(draw.textlength("ML", font=f_logo)))
    draw.text((bench_x, logo_y), "Bench", font=f_logo, fill=C_BLACK)

    # Tagline right-aligned
    tag = "Rice Grain Analysis Report"
    tag_w, _ = _text_size(draw, tag, f_tagline)
    draw.text((PW - MARGIN - tag_w, logo_y + 6), tag, font=f_tagline, fill=C_SLATE_MID)

    # Teal horizontal rule
    rule_y = HDR_H - 4
    draw.rectangle([0, rule_y, PW, rule_y + TEAL_LINE], fill=C_TEAL)

    # ── Section header bars ───────────────────────────────────────────────
    stat_x = MARGIN
    img_x = MARGIN + stat_w + GAP
    bar_y0 = BODY_TOP
    bar_y1 = BODY_TOP + SEC_BAR_H
    bar_cy = (bar_y0 + bar_y1) // 2

    _draw_rect(img, stat_x, bar_y0, stat_x + stat_w, bar_y1, fill=C_TEAL)
    _draw_text_centred(img, stat_x + stat_w // 2, bar_cy,
                       "Grain Count Statistics", f_sec_hdr, C_WHITE)
    _draw_rect(img, img_x, bar_y0, img_x + img_col_w, bar_y1, fill=C_VIOLET)
    _draw_text_centred(img, img_x + img_col_w // 2, bar_cy,
                       "Segmentation Output", f_sec_hdr, C_WHITE)

    # ── Stats table ───────────────────────────────────────────────────────
    table_top = bar_y1
    table_bottom = table_top + TABLE_H
    col_hdr_y1 = table_top + COL_HDR_H
    col_hdr_cy = (table_top + col_hdr_y1) // 2

    # Column header background
    _draw_rect(img, stat_x, table_top, stat_x + stat_w, col_hdr_y1, fill=C_SLATE)
    cells_w = stat_w - STRIPE_W  # table width to the right of the accent stripe
    cat_cx = stat_x + STRIPE_W + cells_w // 2 - int(cells_w * 0.18)
    count_cx = stat_x + STRIPE_W + int(cells_w * 0.78)
    _draw_text_centred(img, cat_cx, col_hdr_cy, "Category", f_col_hdr, C_WHITE)
    _draw_text_centred(img, count_cx, col_hdr_cy, "Count", f_col_hdr, C_WHITE)

    border_color = (203, 213, 225)
    grid_color = (226, 232, 240)

    for i, (label, val, accent, bg) in enumerate(stat_rows_def):
        ry0 = col_hdr_y1 + i * ROW_H
        ry1 = ry0 + ROW_H
        cy = (ry0 + ry1) // 2

        _draw_rect(img, stat_x, ry0, stat_x + stat_w, ry1, fill=bg)           # row background
        _draw_rect(img, stat_x, ry0, stat_x + STRIPE_W, ry1, fill=accent)     # accent stripe
        _draw_text_left(img, stat_x + STRIPE_W + 8, cy, label, f_label, C_SLATE)

        is_total = i == 0
        f_v = f_val_total if is_total else f_val
        c_v = C_VIOLET if is_total else C_SLATE
        vw, _ = _text_size(draw, val, f_v)
        # Right-aligned 14 px from the table edge, vertically placed exactly like the label.
        _draw_text_left(img, stat_x + stat_w - vw - 14, cy, val, f_v, c_v)

        # Horizontal grid line
        draw.rectangle([stat_x, ry1 - 1, stat_x + stat_w, ry1], fill=grid_color)

    # Outer border of table (column header + rows)
    draw.rectangle([stat_x, table_top, stat_x + stat_w, table_bottom],
                   outline=border_color, width=1)

    # ── Segmentation image ────────────────────────────────────────────────
    # Violet-framed black box aligned with, and exactly as tall as, the table.
    box_x0, box_y0 = img_x, table_top
    box_x1, box_y1 = img_x + img_col_w, table_bottom
    _draw_rect(img, box_x0, box_y0, box_x1, box_y1,
               fill=C_BLACK, border=C_VIOLET, border_width=BOX_BORDER)

    # Fit the image inside the frame so it never paints over the border.
    # PIL rectangle boxes include both end coordinates (size + 1 px), and the
    # outline is drawn inward, taking BOX_BORDER px on each side.
    fit_w = img_col_w + 1 - 2 * BOX_BORDER
    fit_h = TABLE_H + 1 - 2 * BOX_BORDER
    iw, ih = segmented_pil.size
    scale = min(fit_w / iw, fit_h / ih)
    new_w = max(1, int(iw * scale))
    new_h = max(1, int(ih * scale))
    seg_resized = segmented_pil.resize((new_w, new_h), Image.BICUBIC)

    # Centre the image inside the framed area (paste converts mode to RGB if needed).
    paste_x = box_x0 + BOX_BORDER + (fit_w - new_w) // 2
    paste_y = box_y0 + BOX_BORDER + (fit_h - new_h) // 2
    img.paste(seg_resized, (paste_x, paste_y))

    return img


# ─── Sample example images ────────────────────────────────────────────────────
SAMPLE_PATHS = [
    "kainat.jpg",
    "c9.jpg",
]


# ─── Status helpers ───────────────────────────────────────────────────────────
def make_status(level: str, message: str) -> dict:
    """Return a ``gr.update`` that shows *message* in the status box.

    The message is prefixed with an icon for *level*. Unknown levels get the
    "info" icon. The update also makes the component visible.
    """
    icons = {"success": "✅", "warning": "⚠️", "error": "❌", "info": "ℹ️"}
    icon = icons.get(level, "ℹ️")
    return gr.update(value=f"{icon} {message}", visible=True)


# ─── Main processing ──────────────────────────────────────────────────────────
def process_image(pil_image):
    """Segment rice grains in *pil_image* and build the report image.

    Returns:
        ``(report_image, status_update)``. ``report_image`` is ``None`` when no
        image was given, no grains were found, or processing failed; the status
        update then carries a warning or error message instead.
    """
    if pil_image is None:
        return None, make_status("warning", "No image provided. Please upload or select a sample image first.")

    try:
        img_np = np.array(pil_image.convert("RGB"))
        img_resized = image_resize(img_np, resize=1000)

        model = get_model()
        masks, _ = run_cellpose(img_resized, model)
        # Count distinct non-zero labels rather than trusting masks.max(): the
        # two agree when labels are contiguous 1..N, and this stays correct if
        # any label ids are skipped.
        total_count = int(np.unique(masks[masks > 0]).size)

        if total_count == 0:
            return None, make_status(
                "warning",
                "No rice grains were detected in this image. "
                "Try a clearer photo or adjust the image contrast."
            )

        outline_pil = build_outline_image(img_resized, masks)
        # The matplotlib render has its own pixel size; bring it back to the
        # analysed image's dimensions.
        outline_pil = outline_pil.resize(
            (img_resized.shape[1], img_resized.shape[0]), resample=Image.BICUBIC
        )

        report_img = build_report_image(outline_pil, total_count)
        return (
            report_img,
            make_status("success", f"{total_count} rice grains detected. Report image shown on the right."),
        )

    except MemoryError:
        return None, make_status("error", "Out of memory. Try uploading a smaller image.")
    except Exception as e:
        # Top-level UI boundary: log the traceback server-side and report the error in the UI.
        import traceback
        traceback.print_exc()
        return None, make_status("error", f"Unexpected error: {type(e).__name__}: {str(e)}")


# ─── UI ───────────────────────────────────────────────────────────────────────
THEME = gr.themes.Soft(
    primary_hue="violet",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)

CSS = """ #run-btn { margin-top: 6px; } #status-box textarea { font-size: 0.92rem; } """

with gr.Blocks(title="Rice Grain Counter") as demo:
    gr.HTML("""
Rice Grain Counter

Upload a rice image to segment each grain and generate a report.

""")

    with gr.Row(equal_height=False):
        # ── LEFT COLUMN ───────────────────────────────────────────────────
        with gr.Column(scale=1):
            inp_image = gr.Image(type="pil", label="Upload Rice Image", height=270)
            run_btn = gr.Button("🔍 Analyse & Generate Report", variant="primary", size="lg", elem_id="run-btn")
            gr.Markdown("_Upload an image or click a sample below, then press **Analyse**._")
            status_box = gr.Textbox(
                label="Status",
                value="",
                interactive=False,
                visible=False,
                max_lines=3,
                elem_id="status-box",
            )
            gr.Markdown("### Example Images _(click to load)_")
            gr.Examples(
                examples=[[p] for p in SAMPLE_PATHS],
                inputs=inp_image,
                label="",
                examples_per_page=6,
            )

        # ── RIGHT COLUMN ──────────────────────────────────────────────────
        with gr.Column(scale=1):
            gr.Markdown("### Report")
            report_out = gr.Image(
                label="",
                interactive=False,
            )

    run_btn.click(
        fn=process_image,
        inputs=[inp_image],
        outputs=[report_out, status_box],
    )

if __name__ == "__main__":
    # Gradio 6 takes both `theme` and `css` in launch() rather than in
    # gr.Blocks(). Pass THEME here so the configured theme is actually applied.
    demo.launch(share=True, css=CSS, theme=THEME)