| |
|
|
| import os |
| import pathlib |
|
|
| import gradio as gr |
| import numpy as np |
| import PIL.Image as Image |
|
|
| from model import Model, random_color, vis_mask |
|
|
# One model instance, built once when this module is executed and shared by
# every call to `run` below.
model: Model = Model()
|
|
|
|
def run(image_path, threshold, max_num_mask):
    """Segment an image with MaskCut and overlay every predicted mask.

    Args:
        image_path: Filesystem path of the input image (the input
            ``gr.Image`` component uses ``type='filepath'``); ``None`` when the
            user has not provided an image.
        threshold: Threshold forwarded unchanged to ``model``.
        max_num_mask: Maximum number of masks, forwarded unchanged to
            ``model``.

    Returns:
        The input image as an RGB NumPy array with each mask drawn on it by
        ``vis_mask`` in a random color.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_path is None:
        # An empty image component yields None; report it in the UI instead
        # of letting PIL fail with an opaque AttributeError.
        raise gr.Error('Please upload an input image.')
    # Use a context manager so the underlying file handle is closed right
    # away instead of lingering until garbage collection. `convert` returns a
    # new image, so the array does not depend on the closed file.
    with Image.open(image_path) as img:
        image = np.asarray(img.convert('RGB'))
    masks = model(image_path, threshold, max_num_mask)
    for mask in masks:
        image = vis_mask(image, mask, random_color(rgb=True))
    return image
|
|
|
|
DESCRIPTION = '# [MaskCut](https://github.com/facebookresearch/CutLER)'

# Sample .jpg images from the CutLER checkout, looked up relative to the
# current working directory and ordered so the examples list is stable.
paths = list(pathlib.Path('CutLER/maskcut/imgs').glob('*.jpg'))
paths.sort()
|
|
# Initial slider values. The example rows reuse them so examples and sliders
# cannot drift apart.
DEFAULT_THRESHOLD = 0.15
DEFAULT_MAX_MASKS = 6

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: input image and MaskCut parameters.
        with gr.Column():
            image = gr.Image(label='Input image', type='filepath')
            threshold = gr.Slider(
                label='Threshold used for producing binary graph',
                minimum=0,
                maximum=1,
                step=0.01,
                value=DEFAULT_THRESHOLD)
            max_masks = gr.Slider(
                label='The maximum number of pseudo-masks per image',
                minimum=1,
                maximum=20,
                step=1,
                value=DEFAULT_MAX_MASKS)
            run_button = gr.Button('Run')
        # Right column: image with the masks overlaid.
        with gr.Column():
            result = gr.Image(label='Result')

    # Order must match the parameters of `run(image_path, threshold,
    # max_num_mask)`.
    inputs = [image, threshold, max_masks]
    gr.Examples(
        examples=[[path.as_posix(), DEFAULT_THRESHOLD, DEFAULT_MAX_MASKS]
                  for path in paths],
        inputs=inputs,
        outputs=result,
        fn=run,
        # Example-output caching is opt-in via the CACHE_EXAMPLES=1
        # environment variable.
        cache_examples=os.getenv('CACHE_EXAMPLES') == '1')

    run_button.click(fn=run, inputs=inputs, outputs=result, api_name='run')

# Enable Gradio's request queue, allowing at most 20 pending requests.
demo.queue(max_size=20).launch()
|
|