| import gradio as gr |
| import torch |
| from carvekit.api.interface import Interface |
| from carvekit.ml.wrap.basnet import BASNET |
| from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 |
| from carvekit.ml.wrap.fba_matting import FBAMatting |
| from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 |
| from carvekit.ml.wrap.u2net import U2NET |
| from carvekit.pipelines.postprocessing import MattingMethod |
| from carvekit.pipelines.preprocessing import PreprocessingStub |
| from carvekit.trimap.generator import TrimapGenerator |
|
|
# Run every model on the GPU when CUDA is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Segmentation backbones selectable from the UI, keyed by the name shown in
# the dropdowns. All of them are instantiated eagerly at import time
# (batch_size=1), so startup cost and memory grow with every entry here.
segment_net = {
    "U2NET": U2NET(device=device, batch_size=1),
    "BASNET": BASNET(device=device, batch_size=1),
    "DeepLabV3": DeepLabV3(device=device, batch_size=1),
    "TracerUniversalB7": TracerUniversalB7(device=device, batch_size=1),
}


# FBA matting network used by the post-processing stage below.
# NOTE(review): input_tensor_size=2048 presumably bounds the image size fed to
# the network; confirm against the carvekit FBAMatting documentation.
fba = FBAMatting(device=device,
                 input_tensor_size=2048,
                 batch_size=1)


# Trimap generator with default settings. It is shared by the matting
# post-processing and by the standalone "Trimap generator" tab.
trimap = TrimapGenerator()


# Pre-processing stage for the removal pipeline (presumably a pass-through
# stub that just runs segmentation — verify in carvekit).
preprocessing = PreprocessingStub()


# Post-processing stage that combines the FBA matting network with the
# trimap generator to refine the segmentation mask.
postprocessing = MattingMethod(matting_module=fba,
                               trimap_generator=trimap,
                               device=device)


# Dropdown choices: the model names in definition order (dicts preserve
# insertion order). Only the keys are needed, so iterate the dict directly.
method_choices = list(segment_net)
|
|
|
|
def generate_trimap(method, original):
    """Segment *original* with the chosen model and return its trimap.

    Args:
        method: Key of ``segment_net`` selecting the segmentation model.
        original: Input PIL image from the Gradio component, or ``None``
            when the user clicked the button without uploading an image.

    Returns:
        Whatever the module-level ``trimap`` generator produces for the image
        and its predicted mask (rendered by a PIL-typed ``gr.Image``).

    Raises:
        gr.Error: If no input image was provided.
    """
    if original is None:
        # Without this guard, None would be handed to the model and fail
        # inside carvekit with an unclear message. gr.Error shows a readable
        # message in the UI instead.
        raise gr.Error("Please upload an input image first.")
    # The segmentation models take a batch (list) of images and return a list
    # of masks; this is a single-image batch.
    masks = segment_net[method]([original])
    return trimap(original_image=original, mask=masks[0])
|
|
|
|
def predict(method, image):
    """Remove the background of *image* using the chosen segmentation model.

    Builds a carvekit ``Interface`` from the shared pre-processing stage, the
    selected segmentation model and the FBA-matting post-processing stage,
    then runs it on a single-image batch.

    Args:
        method: Key of ``segment_net`` selecting the segmentation model.
        image: Input PIL image from the Gradio component, or ``None`` when
            the user clicked the button without uploading an image.

    Returns:
        The first (and only) result of the interface call.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # Fail with a readable UI message instead of an error from deep
        # inside the model pipeline.
        raise gr.Error("Please upload an input image first.")
    # Keep the model in its own variable instead of rebinding the `method`
    # parameter from a name string to a model object.
    seg_model = segment_net[method]
    interface = Interface(pre_pipe=preprocessing,
                          post_pipe=postprocessing,
                          seg_pipe=seg_model)
    return interface([image])[0]
|
|
|
|
# HTML footer rendered below the tabs: the CarveKit logo plus a credit link to
# the upstream CarveKit repository. `<br>` is a void element, so it must not
# be written as a closing tag (`</br>`).
footer = r"""
<center>
<img src='https://raw.githubusercontent.com/leonelhs/image-background-remove-tool/master/docs/imgs/logo.png' alt='CarveKit' width="200" height="80">
<br>
<b>
Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
</b>
</center>
"""
|
|
# Two-tab Gradio UI: full background removal, plus a standalone trimap viewer.
with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")

    with gr.Tabs():
        # Tab 1: segmentation + matting via predict().
        with gr.TabItem("Remove background", id=0):
            with gr.Row(equal_height=False):
                with gr.Column():
                    input_img = gr.Image(type="pil", label="Input image")
                    drp_itf = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices)
                    run_btn = gr.Button(variant="primary")
                with gr.Column():
                    output_img = gr.Image(type="pil", label="result")

            run_btn.click(predict, [drp_itf, input_img], [output_img])

        # Tab 2: show only the trimap via generate_trimap(). This tab's
        # dropdown has its own name so it does not shadow tab 1's `drp_itf`.
        with gr.TabItem("Trimap generator", id=1):
            with gr.Row(equal_height=False):
                with gr.Column():
                    trimap_input = gr.Image(type="pil", label="Input image")
                    trimap_drp = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices)
                    trimap_btn = gr.Button(variant="primary")
                with gr.Column():
                    trimap_output = gr.Image(type="pil", label="result")

            trimap_btn.click(generate_trimap,
                             [trimap_drp, trimap_input],
                             [trimap_output])

    with gr.Row():
        gr.HTML(footer)


# Enable request queuing on the app object. This only configures the app; it
# does not start a server.
app.queue()

# Start the (blocking, debug=True) server only when run as a script, so that
# importing this module does not launch it.
if __name__ == "__main__":
    app.launch(share=False, debug=True, show_error=True)
|
|