| from core import runner |
| import torch |
| from torch import tensor |
| from PIL import Image |
| import numpy as np |
| import torch.nn.functional as F |
| import gradio as gr |
|
|
| |
|
|
# Run on the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'{device.type=}')
|
|
# HTML shown above the interface; adjacent literals are concatenated into a
# single string, so no backslash continuations are needed.
description = (
    '<p> Choose an example below; OR <br>'
    'Upload by yourself: <br>'
    '1. Upload any test image (query) with any target object you wish to segment <br>'
    '2. Upload another image (support) with the target object or a variation of it <br>'
    '3. Upload a binary mask that segments the target object in the support image <br>'
    '</p>'
)
| |
| |
# Bundled example episodes. Each entry lists, in the order of the interface
# inputs: query image, support image, support mask.
example_episodes = [
    [
        './imgs/549870_35.jpg',
        './imgs/457070_00.jpg',
        './imgs/457070_00.png',
    ],
    [
        './imgs/ISIC_0000372.jpg',
        './imgs/ISIC_0013176.jpg',
        './imgs/ISIC_0013176_segmentation.png',
    ],
    [
        './imgs/d_r_450_.jpg',
        './imgs/d_r_465_.jpg',
        './imgs/d_r_465_.bmp',
    ],
    [
        './imgs/CHNCXR_0282_0.png',
        './imgs/CHNCXR_0324_0.png',
        './imgs/CHNCXR_0324_0_mask.png',
    ],
    [
        './imgs/1.jpg',
        './imgs/5.jpg',
        './imgs/5.png',
    ],
    [
        './imgs/cake1.png',
        './imgs/cake2.png',
        './imgs/cake2_mask.png',
    ],
]

# Placeholder image returned whenever a request produces no prediction.
blank_img = './imgs/blank.png'
|
|
def gr_img(name):
    """Create an image input labelled *name* that takes uploads or webcam shots and yields PIL images."""
    return gr.Image(label=name, sources=['upload', 'webcam'], type="pil")


# Query, support and support-mask images, followed by the 're-adapt' toggle.
inputs = [gr_img(label) for label in ('Query Img', 'Support Img', 'Support Mask')]
inputs.append(gr.Checkbox(label='re-adapt'))

# The extra confirmation checkbox is only offered when running on CPU.
if device.type == 'cpu':
    inputs.append(gr.Checkbox(label='Confirm CPU run (CHOOSE ONLY WHEN REQUESTED)'))
|
|
def prepare_feat_maker():
    """Build a new feature maker on the module-level ``device``.

    The config comes from ``runner.makeConfig()``. ``runner.makeFeatureMaker``
    is given a placeholder dataset object whose only attribute is
    ``class_ids = [0]``.

    Returns:
        Whatever ``runner.makeFeatureMaker`` returns.
    """
    cfg = runner.makeConfig()

    class DummyDataset:
        # Stand-in for a real dataset; only ``class_ids`` is provided.
        class_ids = [0]

    return runner.makeFeatureMaker(DummyDataset(), cfg, device=device)
|
|
# Module-level state shared by the request handlers below.
# `feat_maker` is rebuilt by reset_layers(); `has_fit` is switched on by
# from_model() after a forward pass has completed.
feat_maker = prepare_feat_maker()
has_fit: bool = False
|
|
| |
|
|
def reset_layers():
    """Replace the shared feature maker with a fresh one and mark it as not fitted.

    ``has_fit`` is cleared together with ``feat_maker`` so the flag cannot stay
    ``True`` if the following forward pass fails; ``from_model`` sets it back
    to ``True`` once a run has completed.
    """
    global feat_maker, has_fit
    feat_maker = prepare_feat_maker()
    has_fit = False
| |
def prepare_batch(q_img_pil, s_img_pil, s_mask_pil):
    """Assemble a one-episode batch dict from a query image, a support image and its mask.

    Args:
        q_img_pil: Query image (PIL).
        s_img_pil: Support image (PIL).
        s_mask_pil: Support mask (PIL); converted to single-channel 'L' mode.

    Returns:
        Dict with keys ``query_img``, ``support_imgs``, ``support_masks`` and
        ``class_id``. The query gets a leading batch dimension; the support
        image and mask get a batch dimension plus a shot dimension at index 1.
    """
    from data.dataset import FSSDataset

    FSSDataset.initialize(img_size=400, datapath='')
    query = FSSDataset.transform(q_img_pil)
    support = FSSDataset.transform(s_img_pil)

    # F.interpolate needs leading batch/channel dims; they are squeezed away
    # again after the mask has been resized to the support image's H x W.
    mask = torch.tensor(np.array(s_mask_pil.convert('L')))
    mask = mask.unsqueeze(0).unsqueeze(0).float()
    mask = F.interpolate(mask, support.size()[-2:], mode='nearest').squeeze()

    return {
        'query_img': query.unsqueeze(0),
        'support_imgs': support.unsqueeze(0).unsqueeze(1),
        'support_masks': mask.unsqueeze(0).unsqueeze(1),
        'class_id': tensor([0]),
    }
|
|
def norm(t):
    """Min-max scale *t* into the range [0, 1].

    Args:
        t: Array-like object exposing ``min()`` and ``max()`` and supporting
            element-wise arithmetic (used with NumPy arrays in this file).

    Returns:
        ``(t - min) / (max - min)``. When every element is equal the range is
        zero, and the result is all zeros instead of NaN.
    """
    lo = t.min()
    span = t.max() - lo
    # Dividing by 1 for a constant input avoids 0/0 = NaN in the output image.
    return (t - lo) / (span if span != 0 else 1)
def overlay(img, mask):
    """Blend the min-max normalised *img* with *mask*, each weighted 0.5.

    *mask* is indexed with two axes and given a trailing axis so that it
    broadcasts against the last axis of *img*.
    """
    scaled_img = norm(img)
    mask_with_channel = mask[:, :, np.newaxis]
    return scaled_img * 0.5 + mask_with_channel * 0.5
| |
def from_model(q_img, s_img, s_mask):
    """Run one query/support episode through the model and build the display images.

    Args:
        q_img: Query image (PIL).
        s_img: Support image (PIL).
        s_mask: Support mask (PIL).

    Returns:
        Tuple ``(coarse, blended)``: the min-max normalised first element of
        the predicted logits, and the query image overlaid with the first
        predicted mask.

    Side effects:
        Sets the module-level ``has_fit`` flag to ``True`` after the forward
        pass has returned.
    """
    global has_fit

    batch = prepare_batch(q_img, s_img, s_mask)
    sseval = runner.SingleSampleEval(batch, feat_maker)
    pred_logits, pred_mask = sseval.forward()
    has_fit = True

    # Tensor.numpy() only works on CPU tensors that do not require grad.
    # detach()/cpu() are no-ops when the outputs already satisfy that, and
    # make the conversion safe when the model ran on CUDA.
    coarse = pred_logits[0].detach().cpu().numpy()
    mask = pred_mask[0].detach().cpu().numpy()
    query = batch['query_img'][0].permute(1, 2, 0).numpy()

    return norm(coarse), overlay(query, mask)
|
|
def _matches_example_mask(mask):
    """Return True when *mask* is pixel-identical to the mask of a bundled example."""
    mask_arr = np.array(mask)  # convert the upload once, not once per example
    for episode in example_episodes:
        # Context manager closes the file handle that Image.open keeps open.
        with Image.open(episode[2]) as reference:
            if np.array_equal(mask_arr, np.array(reference)):
                return True
    return False


def predict(q, s, m, re_adapt, confirmed=None):
    """Gradio callback: decide whether to (re-)fit, reuse fitted layers, or only report status.

    Args:
        q: Query image.
        s: Support image.
        m: Support mask.
        re_adapt: Value of the 're-adapt' checkbox. ``None`` together with
            ``confirmed is None`` is treated as an example-cache run
            (presumably Gradio pre-computing example outputs -- confirm).
        confirmed: Value of the 'Confirm CPU run' checkbox. That input only
            exists when running on CPU, so on other devices the argument is
            not supplied and defaults to ``None``.

    Returns:
        Tuple ``(status message, coarse prediction, mask overlay)``. Both
        images are the blank placeholder path when no prediction was run.
    """
    print(f'predict with {re_adapt=}, {confirmed=}')
    print(f'{type(q)=}')
    is_cache_run = re_adapt is None and confirmed is None
    is_example = _matches_example_mask(m)
    print(f'{is_example=}')

    if is_cache_run:
        reset_layers()
        pred = from_model(q, s, m)
        msg = 'Results ready.'
        return msg, *pred

    if re_adapt:
        # Confirmation is a CPU-only safeguard; off-CPU there is no checkbox
        # to tick, so re-adaptation proceeds directly.
        if confirmed or device.type != 'cpu':
            reset_layers()
            pred = from_model(q, s, m)
            msg = "Results ready.\nRemember to untick 're-adapt' if you wish to predict more images with the same parameters."
            return msg, *pred
        msg = "You chose to re-adapt but are on CPU.\nThis may take 1 minute on your local machine or 4 minutes on huggingface space.\nSelect 'Confirm CPU run' to start."
        return msg, blank_img, blank_img

    if is_example:
        msg = "Cached results for example have been shown previously already.\nTo view it again, click the example again.\nTo run adaption again from scratch, select 're-adapt'."
        return msg, blank_img, blank_img

    if has_fit:
        pred = from_model(q, s, m)
        msg = "Results predicted based on layers fitted from previous run.\nIf you wish to re-adapt, select 're-adapt'."
        return msg, *pred

    msg = "This is the first time you predict own images.\nThe attached layers need to be fitted.\nPlease select 're-adapt'."
    return msg, blank_img, blank_img
| |
# Wire predict() into a Gradio interface: the shared input components go in,
# a status text plus two result images come out.
gradio_app = gr.Interface(
    title="abcdfss",
    description=description,
    fn=predict,
    inputs=inputs,
    outputs=[
        gr.Textbox(label="Status"),
        gr.Image(label="Coarse Query Prediction"),
        gr.Image(label="Mask Prediction"),
    ],
    examples=example_episodes,
)

# Start serving as soon as the module is executed.
gradio_app.launch()