| import os |
| import sys |
| import skimage.io |
| import numpy as np |
| import pandas as pd |
| import gradio as gr |
| from PIL import Image |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from efficientnet_pytorch import model as enet |
|
|
| import matplotlib.pyplot as plt |
| from tqdm import tqdm_notebook as tqdm |
# --- Tiling / inference configuration ---------------------------------------
tile_size = 256    # edge length (pixels) of each square tile cut from a slide
image_size = 256   # edge length (pixels) of each tile slot in the mosaic
n_tiles = 36       # number of tiles kept per slide
batch_size = 8     # not referenced by the code visible in this file
num_workers = 4    # not referenced by the code visible in this file

# Setting the limit to None turns off Pillow's maximum-pixel-count check.
Image.MAX_IMAGE_PIXELS = None

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class enetv2(nn.Module):
    """EfficientNet backbone with its classifier swapped for a separate head.

    The backbone's own ``_fc`` layer is replaced by ``nn.Identity`` so the
    backbone returns features, and ``myfc`` maps those features to
    ``out_dim`` outputs.
    """

    def __init__(self, backbone, out_dim):
        """Build the network.

        Args:
            backbone: Architecture name passed to ``EfficientNet.from_name``.
            out_dim: Number of outputs produced by the linear head.
        """
        super().__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        # Read the feature width before the original classifier is discarded.
        feature_dim = self.enet._fc.in_features
        self.myfc = nn.Linear(feature_dim, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        """Return the backbone output for ``x`` (classifier is an identity)."""
        return self.enet(x)

    def forward(self, x):
        """Return the head's outputs for ``x``."""
        return self.myfc(self.extract(x))
| |
| |
def load_models(model_files):
    """Load one ``enetv2`` network per checkpoint path.

    Each network is an ``efficientnet-b0`` with 5 outputs. Its weights are
    loaded strictly from the checkpoint, and it is switched to eval mode and
    moved to the module-level ``device``.

    Args:
        model_files: Iterable of checkpoint paths.

    Returns:
        List of loaded models, in the same order as ``model_files``.
    """
    backbone = 'efficientnet-b0'
    loaded = []
    for checkpoint in model_files:
        checkpoint = os.path.join(checkpoint)
        net = enetv2(backbone, out_dim=5)
        # The map_location callback keeps every storage where it was
        # deserialized; the explicit .to(device) below does the placement.
        state = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        net.load_state_dict(state, strict=True)
        net.eval()
        net.to(device)
        loaded.append(net)
        print(f'{checkpoint} loaded!')
    return loaded
|
|
|
|
# Checkpoint paths handed to load_models (one model is built per entry).
model_files = ['base_implementation']

# Models are loaded once, at import time, and reused for every prediction.
models = load_models(model_files)
|
|
def get_tiles(img, mode=0, size=None, count=None):
    """Cut an image into square tiles and keep the darkest ones.

    The image is padded with white (255) so both spatial dimensions are a
    multiple of the tile size, then split into non-overlapping tiles. Tiles
    are ranked by their pixel sum, ascending (least white first), and the
    first ``count`` are returned. When the image yields fewer than ``count``
    tiles, all-white tiles are appended so exactly ``count`` come back.

    Args:
        img: Array of shape ``(height, width, channels)``.
        mode: Adds ``(size * mode) // 2`` extra padding per dimension, which
            shifts the tile grid. The padded size must stay a multiple of
            ``size``, so use an even ``mode`` (this file uses 0 and 2).
        size: Tile edge length in pixels; defaults to module-level
            ``tile_size``.
        count: Number of tiles to return; defaults to module-level
            ``n_tiles``.

    Returns:
        ``(tiles, enough)`` where ``tiles`` is a list of
        ``{'img': tile_array, 'idx': position}`` dicts in ranked order, and
        ``enough`` is true when at least ``count`` tiles are not pure white.
    """
    size = tile_size if size is None else size
    count = n_tiles if count is None else count

    h, w, c = img.shape
    shift = (size * mode) // 2
    pad_h = (size - h % size) % size + shift
    pad_w = (size - w % size) % size + shift

    # Split the padding as evenly as possible between the two sides.
    padded = np.pad(
        img,
        [
            [pad_h // 2, pad_h - pad_h // 2],
            [pad_w // 2, pad_w - pad_w // 2],
            [0, 0],
        ],
        constant_values=255,
    )

    # (rows, size, cols, size, c) -> (rows * cols, size, size, c)
    grid = padded.reshape(
        padded.shape[0] // size, size, padded.shape[1] // size, size, c
    )
    tiles = grid.transpose(0, 2, 1, 3, 4).reshape(-1, size, size, c)

    # Top up with white tiles when the image is too small. The original
    # compared the image height instead of the tile count and referenced an
    # undefined name, so this padding never ran correctly.
    if len(tiles) < count:
        tiles = np.pad(
            tiles,
            [[0, count - len(tiles)], [0, 0], [0, 0], [0, 0]],
            constant_values=255,
        )

    sums = tiles.reshape(tiles.shape[0], -1).sum(1)
    # A tile carries information when it is not entirely white.
    n_tiles_with_info = (sums < size ** 2 * c * 255).sum()

    order = np.argsort(sums)[:count]
    result = [{'img': tile, 'idx': i} for i, tile in enumerate(tiles[order])]
    return result, n_tiles_with_info >= count
|
|
def getitem(img, tile_mode):
    """Turn an image file into a single tiled-mosaic input tensor.

    The first image returned by ``skimage.io.MultiImage`` for the file is
    tiled with ``get_tiles``. The tiles are inverted (``255 - tile``) and
    pasted, in a random order, into a square mosaic of
    ``int(sqrt(n_tiles))`` slots per side. Slots whose tile is missing are
    filled with an inverted white tile, i.e. zeros.

    NOTE: the slot order is drawn from ``np.random`` on every call, so the
    same file can produce differently arranged mosaics across calls.

    Args:
        img: Path of the image file to read.
        tile_mode: Forwarded to ``get_tiles`` as its ``mode`` argument.

    Returns:
        Float32 tensor of shape ``(1, 3, side, side)`` with values scaled to
        ``[0, 1]``, where ``side = image_size * int(sqrt(n_tiles))``.
    """
    slide = skimage.io.MultiImage(img)[0]
    tiles, _ = get_tiles(slide, tile_mode)

    order = np.random.choice(list(range(n_tiles)), n_tiles, replace=False)

    per_side = int(np.sqrt(n_tiles))
    canvas = np.zeros((image_size * per_side, image_size * per_side, 3))
    for slot in range(per_side * per_side):
        row, col = divmod(slot, per_side)
        pick = order[slot]
        if pick < len(tiles):
            patch = tiles[pick]['img']
        else:
            # No tile for this slot: use a white placeholder.
            patch = np.ones((image_size, image_size, 3)).astype(np.uint8) * 255
        top = row * image_size
        left = col * image_size
        # Each patch must be image_size x image_size to fit its slot.
        canvas[top:top + image_size, left:left + image_size] = 255 - patch

    canvas = canvas.astype(np.float32)
    canvas /= 255
    # Channels-last (H, W, C) -> channels-first (C, H, W).
    canvas = canvas.transpose(2, 0, 1)

    return torch.tensor(canvas).unsqueeze(0)
|
|
|
|
def predict_label(im):
    """Predict a grade for one image file using two tilings.

    The file is converted to model inputs with tile modes 0 and 2. The first
    loaded model scores both, and the two sets of sigmoid outputs are
    averaged. The prediction is the rounded sum of the averaged outputs.

    Args:
        im: Path of the image file, forwarded to ``getitem``.

    Returns:
        NumPy array holding one rounded value per batch row (a single
        element here, since each input has batch size 1).
    """
    views = [getitem(im, tile_mode) for tile_mode in (0, 2)]
    net = models[0]

    with torch.no_grad():
        raw_outputs = [net(view.to(device)) for view in views]

    first, second = (out.sigmoid().cpu() for out in raw_outputs)
    averaged = (first + second) / 2
    return averaged.sum(1).round().numpy()
|
|
|
|
|
|
def classify_images(im):
    """Return the user-facing result message for an uploaded image.

    Args:
        im: Path of the uploaded image file, forwarded to ``predict_label``.

    Returns:
        Message string ending in the predicted grade as a plain integer.
    """
    pred = predict_label(im)
    # predict_label returns a one-element array; formatting it directly
    # would render as "[3.]", so pull out the scalar first.
    grade = int(np.asarray(pred).reshape(-1)[0])
    return 'Your submitted case has Prostate cancer of ISUP Grade ' + str(grade)
|
|
|
|
|
|
# Gradio front end: one image upload (delivered as a file path) in, one label out.
img = gr.Image(label="Upload Image", type="filepath")
label = gr.Label()
examples = ["5.tiff", "6.tiff", "7.tiff"]

intf = gr.Interface(
    title="PCa Detection ProtoType",
    description="This is Protorype for our model prediction",
    fn=classify_images,
    inputs=img,
    outputs=label,
    examples=examples,
)
intf.launch(inline=False)