Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import io | |
| from PIL import Image, ImageDraw, ImageFont | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| # βββ Font paths βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| FONT_BOLD = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf" | |
| FONT_REGULAR = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf" | |
| def _font(size, bold=True): | |
| try: | |
| return ImageFont.truetype(FONT_BOLD if bold else FONT_REGULAR, size) | |
| except Exception: | |
| return ImageFont.load_default() | |
| # βββ Colours (R,G,B) ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| C_VIOLET = (124, 58, 237) | |
| C_VIOLET_DARK = ( 91, 33, 182) | |
| C_VIOLET_LITE = (237, 233, 254) | |
| C_TEAL = ( 13, 148, 136) | |
| C_TEAL_LITE = (204, 251, 241) | |
| C_AMBER = (217, 119, 6) | |
| C_AMBER_LITE = (254, 243, 199) | |
| C_ROSE = (225, 29, 72) | |
| C_ROSE_LITE = (255, 228, 230) | |
| C_SLATE = ( 30, 41, 59) | |
| C_SLATE_MID = (100, 116, 139) | |
| C_SLATE_LITE = (241, 245, 249) | |
| C_WHITE = (255, 255, 255) | |
| C_BLACK = ( 0, 0, 0) | |
| C_RED = (220, 38, 38) | |
| C_BG = (255, 255, 255) # page background | |
| # βββ Cellpose model (lazy) ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _model = None | |
| def get_model(): | |
| global _model | |
| if _model is None: | |
| from cellpose import models | |
| from huggingface_hub import hf_hub_download | |
| fpath = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam") | |
| _model = models.CellposeModel(gpu=False, pretrained_model=fpath) | |
| return _model | |
| # βββ Image helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def normalize99(img): | |
| X = img.copy().astype(np.float32) | |
| p1, p99 = np.percentile(X, 1), np.percentile(X, 99) | |
| return (X - p1) / (1e-10 + p99 - p1) | |
| def image_resize(img, resize=1000): | |
| ny, nx = img.shape[:2] | |
| if max(ny, nx) > resize: | |
| if ny > nx: | |
| nx = int(nx / ny * resize); ny = resize | |
| else: | |
| ny = int(ny / nx * resize); nx = resize | |
| img = cv2.resize(img, (nx, ny)) | |
| return img.astype(np.uint8) | |
| def run_cellpose(img, model, max_iter=250, flow_threshold=0.4, cellprob_threshold=0.0): | |
| masks, flows, _ = model.eval( | |
| img, niter=max_iter, | |
| flow_threshold=flow_threshold, | |
| cellprob_threshold=cellprob_threshold, | |
| ) | |
| return masks, flows | |
| def build_outline_image(img, masks) -> Image.Image: | |
| img_n = np.clip(normalize99(img), 0, 1) | |
| outpix = [] | |
| contours, _ = cv2.findContours( | |
| masks.astype(np.int32), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_SIMPLE | |
| ) | |
| for c in contours: | |
| if len(c.astype(int).squeeze()) > 4: | |
| outpix.append(cv2.approxPolyDP(c, 0.001, True)[:, 0, :]) | |
| h, w = img_n.shape[:2] | |
| figsize = (6, 6 * h / w) if w >= h else (6 * w / h, 6) | |
| fig = plt.figure(figsize=figsize, facecolor="k") | |
| ax = fig.add_axes([0, 0, 1, 1]) | |
| ax.set_xlim([0, w]); ax.set_ylim([0, h]) | |
| ax.imshow(img_n[::-1], origin="upper", aspect="auto") | |
| for o in outpix: | |
| ax.plot(o[:, 0], h - o[:, 1], color=[1, 0, 0], lw=1) | |
| ax.axis("off") | |
| buf = io.BytesIO() | |
| fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0) | |
| buf.seek(0) | |
| out = Image.open(buf).copy() | |
| plt.close(fig) | |
| return out | |
| # βββ Drawing helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _text_size(draw, text, font): | |
| """Return (width, height) of text.""" | |
| bbox = draw.textbbox((0, 0), text, font=font) | |
| return bbox[2] - bbox[0], bbox[3] - bbox[1] | |
| def _draw_rect(img, x0, y0, x1, y1, fill, border=None, border_width=2, radius=0): | |
| """Draw a filled rectangle with optional border on a PIL Image.""" | |
| draw = ImageDraw.Draw(img) | |
| if radius > 0: | |
| draw.rounded_rectangle([x0, y0, x1, y1], radius=radius, fill=fill, | |
| outline=border, width=border_width if border else 0) | |
| else: | |
| draw.rectangle([x0, y0, x1, y1], fill=fill, | |
| outline=border, width=border_width if border else 0) | |
| def _draw_text_centred(img, cx, cy, text, font, color): | |
| draw = ImageDraw.Draw(img) | |
| tw, th = _text_size(draw, text, font) | |
| draw.text((cx - tw // 2, cy - th // 2), text, font=font, fill=color) | |
| def _draw_text_left(img, x, cy, text, font, color): | |
| draw = ImageDraw.Draw(img) | |
| _, th = _text_size(draw, text, font) | |
| draw.text((x, cy - th // 2), text, font=font, fill=color) | |
| # βββ Report image builder βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def build_report_image(segmented_pil: Image.Image, total_count: int) -> Image.Image: | |
| """ | |
| Renders the full report as a PIL Image with the same structure as the PDF: | |
| β’ Header : MLBench + tagline + teal rule | |
| β’ Body : [Grain Count Statistics table] | gap | [Segmentation Output image] | |
| No footer line / page number. | |
| """ | |
| DPI = 150 | |
| PW_IN = 8.27 # A4 width in inches | |
| PH_IN = 11.69 # A4 height in inches (we'll crop to content) | |
| PW = int(PW_IN * DPI) | |
| MARGIN = int(0.7 * DPI) # ~0.7 inch margin | |
| # ββ Fonts βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| f_logo_ml = _font(int(0.28 * DPI)) # "ML" large | |
| f_logo_b = _font(int(0.28 * DPI)) # "Bench" same size | |
| f_tagline = _font(int(0.09 * DPI), bold=False) | |
| f_sec_hdr = _font(int(0.11 * DPI)) # section bar text | |
| f_col_hdr = _font(int(0.09 * DPI)) # table column headers | |
| f_label = _font(int(0.10 * DPI)) # row labels | |
| f_val_total = _font(int(0.13 * DPI)) # total count value (bigger) | |
| f_val = _font(int(0.10 * DPI)) # other value cells | |
| # ββ Dimensions ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| usable_w = PW - 2 * MARGIN | |
| GAP = int(0.18 * DPI) | |
| stat_w = int(usable_w * 0.43) | |
| img_col_w = usable_w - stat_w - GAP | |
| HDR_H = int(0.55 * DPI) # header area height | |
| SEC_BAR_H = int(0.22 * DPI) # coloured section title bar | |
| COL_HDR_H = int(0.18 * DPI) # table column header row | |
| ROW_H = int(0.17 * DPI) # each data row | |
| STRIPE_W = int(0.07 * DPI) # coloured left stripe on each row | |
| TEAL_LINE = 3 # teal rule thickness | |
| N_ROWS = 5 | |
| TABLE_H = COL_HDR_H + N_ROWS * ROW_H | |
| # Total canvas height: margin + header + gap + sec_bar + content + margin | |
| BODY_TOP = HDR_H + int(0.12 * DPI) # y where body starts | |
| CONTENT_H = SEC_BAR_H + TABLE_H | |
| CANVAS_H = BODY_TOP + CONTENT_H + MARGIN | |
| # ββ Create canvas βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| img = Image.new("RGB", (PW, CANVAS_H), C_BG) | |
| draw = ImageDraw.Draw(img) | |
| # ββ Header ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # "ML" in red, "Bench" in black | |
| logo_y = int(HDR_H * 0.38) | |
| ml_w, _ = _text_size(draw, "ML", f_logo_ml) | |
| draw.text((MARGIN, logo_y), "ML", font=f_logo_ml, fill=C_RED) | |
| draw.text((MARGIN + ml_w, logo_y), "Bench", font=f_logo_b, fill=C_BLACK) | |
| # Tagline right-aligned | |
| tag = "Rice Grain Analysis Report" | |
| tag_w, tag_h = _text_size(draw, tag, f_tagline) | |
| draw.text((PW - MARGIN - tag_w, logo_y + 6), tag, font=f_tagline, fill=C_SLATE_MID) | |
| # Teal horizontal rule | |
| rule_y = HDR_H - 4 | |
| draw.rectangle([0, rule_y, PW, rule_y + TEAL_LINE], fill=C_TEAL) | |
| # ββ Section header bars βββββββββββββββββββββββββββββββββββββββββββββββ | |
| stat_x = MARGIN | |
| img_x = MARGIN + stat_w + GAP | |
| stat_bar_y0 = BODY_TOP | |
| stat_bar_y1 = BODY_TOP + SEC_BAR_H | |
| # Teal bar β "Grain Count Statistics" | |
| _draw_rect(img, stat_x, stat_bar_y0, stat_x + stat_w, stat_bar_y1, fill=C_TEAL) | |
| _draw_text_centred(img, stat_x + stat_w // 2, (stat_bar_y0 + stat_bar_y1) // 2, | |
| "Grain Count Statistics", f_sec_hdr, C_WHITE) | |
| # Violet bar β "Segmentation Output" | |
| _draw_rect(img, img_x, stat_bar_y0, img_x + img_col_w, stat_bar_y1, fill=C_VIOLET) | |
| _draw_text_centred(img, img_x + img_col_w // 2, (stat_bar_y0 + stat_bar_y1) // 2, | |
| "Segmentation Output", f_sec_hdr, C_WHITE) | |
| # ββ Stats table βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| table_top = BODY_TOP + SEC_BAR_H | |
| col_hdr_y0 = table_top | |
| col_hdr_y1 = table_top + COL_HDR_H | |
| # Column header background | |
| _draw_rect(img, stat_x, col_hdr_y0, stat_x + stat_w, col_hdr_y1, fill=C_SLATE) | |
| cat_cx = stat_x + STRIPE_W + (stat_w - STRIPE_W) // 2 - int((stat_w - STRIPE_W) * 0.18) | |
| count_cx = stat_x + STRIPE_W + int((stat_w - STRIPE_W) * 0.78) | |
| _draw_text_centred(img, cat_cx, (col_hdr_y0 + col_hdr_y1) // 2, "Category", f_col_hdr, C_WHITE) | |
| _draw_text_centred(img, count_cx, (col_hdr_y0 + col_hdr_y1) // 2, "Count", f_col_hdr, C_WHITE) | |
| stat_rows_def = [ | |
| ("Total Rice Grain", str(total_count), C_VIOLET, C_VIOLET_LITE), | |
| ("Long Grain", "β", C_TEAL, C_TEAL_LITE), | |
| ("Short Grain", "β", C_AMBER, C_AMBER_LITE), | |
| ("Half Grain", "β", C_ROSE, C_ROSE_LITE), | |
| ("Broken Edge", "β", C_SLATE_MID,C_SLATE_LITE), | |
| ] | |
| border_color = (203, 213, 225) | |
| grid_color = (226, 232, 240) | |
| for i, (label, val, accent, bg) in enumerate(stat_rows_def): | |
| ry0 = table_top + COL_HDR_H + i * ROW_H | |
| ry1 = ry0 + ROW_H | |
| cy = (ry0 + ry1) // 2 | |
| # Row background | |
| _draw_rect(img, stat_x, ry0, stat_x + stat_w, ry1, fill=bg) | |
| # Accent stripe | |
| _draw_rect(img, stat_x, ry0, stat_x + STRIPE_W, ry1, fill=accent) | |
| # Label | |
| f_lbl = f_label | |
| _draw_text_left(img, stat_x + STRIPE_W + 8, cy, label, f_lbl, C_SLATE) | |
| # Value | |
| f_v = f_val_total if i == 0 else f_val | |
| c_v = C_VIOLET if i == 0 else C_SLATE | |
| vw, _ = _text_size(draw, val, f_v) | |
| draw.text((stat_x + stat_w - vw - 14, cy - _text_size(draw, val, f_v)[1] // 2), | |
| val, font=f_v, fill=c_v) | |
| # Horizontal grid line | |
| draw.rectangle([stat_x, ry1 - 1, stat_x + stat_w, ry1], fill=grid_color) | |
| # Outer border of table (column header + rows) | |
| draw.rectangle([stat_x, col_hdr_y0, stat_x + stat_w, | |
| table_top + COL_HDR_H + N_ROWS * ROW_H], outline=border_color, width=1) | |
| # ββ Segmentation image ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Fit segmented image to exactly match table height (SEC_BAR already above) | |
| target_h = TABLE_H # must match table area below sec bar | |
| target_w = img_col_w | |
| seg_np = np.array(segmented_pil) | |
| ih, iw = seg_np.shape[:2] | |
| scale = min(target_w / iw, target_h / ih) | |
| new_w = int(iw * scale) | |
| new_h = int(ih * scale) | |
| seg_resized = segmented_pil.resize((new_w, new_h), Image.BICUBIC) | |
| # Black background box β same height as table | |
| box_x0 = img_x | |
| box_y0 = table_top # align top with table (below sec bar) | |
| box_x1 = img_x + img_col_w | |
| box_y1 = table_top + TABLE_H | |
| _draw_rect(img, box_x0, box_y0, box_x1, box_y1, | |
| fill=C_BLACK, border=C_VIOLET, border_width=2) | |
| # Centre the image inside the black box | |
| paste_x = box_x0 + (img_col_w - new_w) // 2 | |
| paste_y = box_y0 + (TABLE_H - new_h) // 2 | |
| img.paste(seg_resized, (paste_x, paste_y)) | |
| return img | |
| # βββ Sample example images ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| SAMPLE_PATHS = [ | |
| "kainat.jpg", | |
| "c9.jpg" | |
| ] | |
| # βββ Status helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def make_status(level: str, message: str) -> dict: | |
| icons = {"success": "β ", "warning": "β οΈ", "error": "β", "info": "βΉοΈ"} | |
| icon = icons.get(level, "βΉοΈ") | |
| return gr.update(value=f"{icon} {message}", visible=True) | |
| # βββ Main processing ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def process_image(pil_image): | |
| # Returns: (report_image, status_update) | |
| if pil_image is None: | |
| return None, make_status("warning", "No image provided. Please upload or select a sample image first.") | |
| try: | |
| img_np = np.array(pil_image.convert("RGB")) | |
| img_resized = image_resize(img_np, resize=1000) | |
| model = get_model() | |
| masks, _ = run_cellpose(img_resized, model) | |
| total_count = int(masks.max()) | |
| if total_count == 0: | |
| return None, make_status( | |
| "warning", | |
| "No rice grains were detected in this image. " | |
| "Try a clearer photo or adjust the image contrast." | |
| ) | |
| outline_pil = build_outline_image(img_resized, masks) | |
| outline_pil = outline_pil.resize( | |
| (img_resized.shape[1], img_resized.shape[0]), resample=Image.BICUBIC | |
| ) | |
| report_img = build_report_image(outline_pil, total_count) | |
| return ( | |
| report_img, | |
| make_status("success", f"{total_count} rice grains detected. Report image shown on the right."), | |
| ) | |
| except MemoryError: | |
| return None, make_status("error", "Out of memory. Try uploading a smaller image.") | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return None, make_status("error", f"Unexpected error: {type(e).__name__}: {str(e)}") | |
| # βββ UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| THEME = gr.themes.Soft( | |
| primary_hue="violet", | |
| secondary_hue="indigo", | |
| neutral_hue="slate", | |
| font=gr.themes.GoogleFont("Inter"), | |
| ) | |
| CSS = """ | |
| #run-btn { margin-top: 6px; } | |
| #status-box textarea { font-size: 0.92rem; } | |
| """ | |
| with gr.Blocks(title="Rice Grain Counter") as demo: | |
| gr.HTML(""" | |
| <div style="padding:18px 12px 10px 12px; background-color:#0F172A; border-radius:10px; margin-bottom:10px;"> | |
| <span style="font-size:2rem;font-weight:900;color:#F1F5F9;font-family:sans-serif;"> | |
| Rice Grain Counter | |
| </span> | |
| <p style="color:#CBD5E1;font-size:0.9rem;margin-top:4px;font-family:sans-serif;"> | |
| Upload a rice image to segment each grain and generate a report. | |
| </p> | |
| </div> | |
| """) | |
| with gr.Row(equal_height=False): | |
| # ββ LEFT COLUMN βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Column(scale=1): | |
| inp_image = gr.Image(type="pil", label="Upload Rice Image", height=270) | |
| run_btn = gr.Button("π Analyse & Generate Report", | |
| variant="primary", size="lg", elem_id="run-btn") | |
| gr.Markdown("_Upload an image or click a sample below, then press **Analyse**._") | |
| status_box = gr.Textbox( | |
| label="Status", | |
| value="", | |
| interactive=False, | |
| visible=False, | |
| max_lines=3, | |
| elem_id="status-box", | |
| ) | |
| gr.Markdown("### Example Images _(click to load)_") | |
| gr.Examples( | |
| examples=[[p] for p in SAMPLE_PATHS], | |
| inputs=inp_image, | |
| label="", | |
| examples_per_page=6, | |
| ) | |
| # ββ RIGHT COLUMN ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Report") | |
| report_out = gr.Image( | |
| label="", | |
| interactive=False, | |
| ) | |
| run_btn.click( | |
| fn=process_image, | |
| inputs=[inp_image], | |
| outputs=[report_out, status_box], | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(share=True, css=CSS) |