| import subprocess |
| import sys |
# Reinstall mmcv-full 1.3.17 from OpenMMLab's CPU / torch 1.10.0 wheel index
# before icevision (imported further down) is loaded. Presumably this replaces
# an mmcv build that does not match the runtime -- TODO confirm.
# check_call raises CalledProcessError if either pip invocation fails.
MMCV_REQUIREMENT = "mmcv-full==1.3.17"
MMCV_FIND_LINKS = "https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html"
pip_command = [sys.executable, "-m", "pip"]

print("Reinstalling mmcv")
subprocess.check_call([*pip_command, "uninstall", "-y", MMCV_REQUIREMENT])
subprocess.check_call([*pip_command, "install", MMCV_REQUIREMENT, "-f", MMCV_FIND_LINKS])
print("mmcv install complete")
|
|
| |
|
|
| from gradio.outputs import Label |
| from icevision.all import * |
| from icevision.models.checkpoint import * |
| import PIL |
| import gradio as gr |
| import os |
|
|
| |
# Restore the trained detector together with the metadata saved alongside it
# in the checkpoint (model type, class map, image size).
checkpoint_path = "models/model_checkpoint.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)
model, model_type, class_map = (
    checkpoint_and_model[key] for key in ("model", "model_type", "class_map")
)

# Inference-time transforms: resize-and-pad to the checkpoint's image size,
# then normalize.
img_size = checkpoint_and_model["img_size"]
resize_pad_transforms = tfms.A.resize_and_pad(img_size)
valid_tfms = tfms.A.Adapter([*resize_pad_transforms, tfms.A.Normalize()])
|
|
def list_sample_images(directory="sample_images/"):
    """Return the sorted paths of the regular files directly inside *directory*.

    Subdirectories are not descended into. Sorting makes the result independent
    of the filesystem's listing order, which ``os.listdir`` leaves arbitrary.
    If *directory* does not exist, a message is printed and an empty list is
    returned so the demo can still start without examples.
    """
    if not os.path.isdir(directory):
        print("Sample image directory not found:", directory)
        return []
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


example_images = list_sample_images()
for image_path in example_images:
    print("Loading sample image:", os.path.basename(image_path))

# Per-example control settings, paired in order with the sorted images:
# (display_label, display_bbox, detection_threshold).
example_settings = [
    (False, True, 0.5),
    (True, True, 0.5),
    (False, True, 0.7),
    (True, True, 0.7),
    (False, True, 0.5),
    (False, True, 0.5),
    (False, True, 0.6),
    (False, True, 0.6),
]

# Each example row holds one value per input component, so the image entry is
# a plain path string. zip() stops at the shorter sequence, so a folder with
# fewer than eight images yields fewer examples instead of an IndexError.
examples = [
    [image_path, *settings]
    for image_path, settings in zip(example_images, example_settings)
]
if not examples:
    # No sample images: fall back to gr.Interface's default of no examples.
    examples = None
|
|
|
|
def show_preds(input_image, display_label, display_bbox, detection_threshold):
    """Run the detector on one image and return the annotated image and box count.

    Args:
        input_image: Image array from the Gradio ``"image"`` input, or ``None``
            when the user submits without uploading an image. Assumed to be an
            H x W x 3 uint8 RGB array -- TODO confirm against the Gradio version.
        display_label: Whether to draw class labels on the returned image.
        display_bbox: Whether to draw bounding boxes on the returned image.
        detection_threshold: Score threshold from the slider; ``0`` is treated
            as ``0.5``.

    Returns:
        A ``(image, count)`` tuple: the annotated image returned by
        ``end2end_detect`` and the number of detected bounding boxes.

    Raises:
        ValueError: If no image was provided.
    """
    if input_image is None:
        # PIL.Image.fromarray would otherwise fail on None with a much less
        # helpful error.
        raise ValueError("No input image provided; please upload an image.")
    if detection_threshold == 0:
        # The slider allows 0; treat it as the 0.5 default (presumably because a
        # zero threshold would keep every low-score box -- confirm).
        detection_threshold = 0.5

    # Import the submodule explicitly: a bare `import PIL` does not load PIL.Image.
    from PIL import Image

    img = Image.fromarray(input_image, "RGB")
    pred_dict = model_type.end2end_detect(
        img,
        valid_tfms,
        model,
        class_map=class_map,
        detection_threshold=detection_threshold,
        display_label=display_label,
        display_bbox=display_bbox,
        return_img=True,
        font_size=16,
        label_color="#FF59D6",
    )
    return pred_dict["img"], len(pred_dict["detection"]["bboxes"])
|
|
|
|
| |
# Footer HTML shown under the demo: a centred link to the accompanying blog post.
article = "<p style='text-align: center'><a href='https://dicksonneoh.com/' target='_blank'>Blog post</a></p>"

# Extra inputs beside the image: drawing toggles and the score-threshold slider.
display_chkbox_label = gr.inputs.Checkbox(default=False, label="Label")
display_chkbox_box = gr.inputs.Checkbox(default=True, label="Box")
detection_threshold_slider = gr.inputs.Slider(
    label="Detection Threshold",
    minimum=0,
    maximum=1,
    step=0.1,
    default=0.5,
)

# Outputs, in the order show_preds returns them: annotated image, box count.
annotated_image_output = gr.outputs.Image(label="RetinaNet Inference", type="pil")
count_output = gr.outputs.Textbox(label="Microalgae Count", type="number")
outputs = [annotated_image_output, count_output]
|
|
| |
# Assemble the demo: the raw image input plus the display/threshold controls,
# wired to show_preds and pre-populated with the example rows.
interface_inputs = [
    "image",
    display_chkbox_label,
    display_chkbox_box,
    detection_threshold_slider,
]
interface_description = "This RetinaNet model counts microalgaes on a given image. Upload an image or click an example image below to use."

gr_interface = gr.Interface(
    fn=show_preds,
    inputs=interface_inputs,
    outputs=outputs,
    examples=examples,
    title="Microalgae Detector with RetinaNet",
    description=interface_description,
    article=article,
)
| |
| |
| |
| |
# Start the Gradio server: no public share link, no inline notebook embedding,
# and debug mode (which blocks the main thread while the app runs).
gr_interface.launch(debug=True, share=False, inline=False)