rgera's picture
desc added
1463b4a
import gradio as gr
from PIL import Image
from pathlib import Path
import os
from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
import torch
# Detector used by the app: torchvision's Faster R-CNN (ResNet-50 FPN v2)
# initialised with the best available pretrained weights. With
# box_score_thresh=0.9 the model itself drops detections scoring 0.9 or less,
# so callers only ever see high-confidence boxes (possibly none at all).
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
from PIL import Image,ImageDraw
from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image,to_tensor
from pathlib import Path
import torch
# Preprocessing transform bundled with the pretrained weights; every image must
# go through it before being handed to the detector. The detector itself is the
# `model` already built from these same weights near the top of the file, so it
# is not constructed again here (doing so would load the pretrained weights a
# second time at startup).
preprocess = weights.transforms()
def predict1(inp1):
    """Detect objects in an image and outline the highest-scoring detection.

    Args:
        inp1: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A ``(label, image)`` pair for the two Gradio outputs: the category
        name of the top detection, and a copy of the input with that
        detection's bounding box drawn in red. If there is no image, or no
        detection survives the model's score threshold, a short message is
        returned in place of the label (and ``None`` / the unmarked copy as
        the image).
    """
    if inp1 is None:
        return "No image provided", None

    model.eval()
    # convert() returns a new image, so the caller's image is never drawn on;
    # it also normalises RGBA/greyscale uploads to the 3 channels the model
    # expects.
    image = inp1.convert("RGB")

    # Pure inference: skip autograd bookkeeping, which only costs memory/time.
    with torch.inference_mode():
        prediction = model([preprocess(image)])[0]

    boxes = prediction["boxes"]
    if len(boxes) == 0:
        # Nothing scored above box_score_thresh; indexing [0] would raise.
        return "No objects detected", image

    # torchvision's detection post-processing keeps detections in decreasing
    # score order, so index 0 is the most confident one.
    label = weights.meta["categories"][int(prediction["labels"][0])]
    # Box is [x1, y1, x2, y2] in pixel coordinates of the input image.
    ImageDraw.Draw(image).rectangle(boxes[0].tolist(), outline="red", width=2)
    return label, image
# Example inputs shown under the UI: one single-element row per file in
# examples/. Entries are sorted so the order is stable (directory listing order
# is arbitrary), hidden files such as .DS_Store and sub-directories are skipped
# because they are not images, and a missing directory just yields no examples
# instead of crashing at import.
examples_path = Path("examples/")
example_list = (
    [
        [str(example)]
        for example in sorted(examples_path.iterdir())
        if example.is_file() and not example.name.startswith(".")
    ]
    if examples_path.is_dir()
    else []
)

title = "Object Detection Model😀😀"
description = (
    "An object detection model based on a pretrained ResNet neural network. "
    "It can detect car, airplane, bus, boat, train, bird, horse, suitcase, "
    "handbag, etc."
)
article = "Created by [Rounak Gera](https://github.com/rounak890)"

demo = gr.Interface(
    fn=predict1,
    inputs=[gr.components.Image(type="pil", label="Image 1")],
    outputs=[
        gr.components.Label(num_top_classes=4, label="Predictions"),
        gr.components.Image(type="pil", label="predicted"),
    ],
    title=title,
    description=description,
    article=article,
    # None is the Interface default meaning "no examples".
    examples=example_list or None,
)
demo.launch()