"""Gradio demo: object detection with a pretrained Faster R-CNN (ResNet-50 FPN v2).

The user uploads an image. The app runs torchvision's pretrained detector on it
and shows the detected classes next to a copy of the image with every detected
bounding box drawn on it.
"""

import os  # noqa: F401  (retained from the original script)
from pathlib import Path

import gradio as gr
import torch
from PIL import Image, ImageDraw  # noqa: F401
from torchvision import transforms  # noqa: F401
from torchvision.io.image import read_image  # noqa: F401
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)
from torchvision.transforms.functional import to_pil_image, to_tensor  # noqa: F401
from torchvision.utils import draw_bounding_boxes  # noqa: F401

# Minimum confidence passed to the detector as `box_score_thresh`.
SCORE_THRESHOLD = 0.9
BOX_COLOR = "red"
BOX_WIDTH = 2

# Build the model exactly once at import time; every request reuses it.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=SCORE_THRESHOLD)
model.eval()  # inference mode for layers like dropout/batch-norm; set once, not per request
preprocess = weights.transforms()
CATEGORIES = weights.meta["categories"]


def predict1(inp1):
    """Detect objects in *inp1* and return ``(predictions, annotated_image)``.

    Args:
        inp1: The uploaded PIL image, or ``None`` if nothing was uploaded.

    Returns:
        A pair matching the two Gradio outputs.

        * ``predictions`` maps each detected class name to its highest
          detection score, so the ``Label`` component can rank up to
          ``num_top_classes`` of them. When there is nothing to rank it is a
          plain message string instead.
        * ``annotated_image`` is an RGB copy of the input with every detected
          box outlined and labelled. It is ``None`` if no image was given.
          The caller's image is never modified.
    """
    if inp1 is None:
        return "No image provided", None

    # The detector expects 3-channel input, so RGBA, greyscale and palette
    # uploads are normalised here. Image.convert returns a new image, which
    # means the drawing below never touches the caller's object.
    image = inp1.convert("RGB")

    # No autograd bookkeeping is needed for inference.
    with torch.inference_mode():
        prediction = model([preprocess(image)])[0]

    boxes = prediction["boxes"].tolist()
    names = [CATEGORIES[i] for i in prediction["labels"].tolist()]
    scores = prediction["scores"].tolist()

    # With a 0.9 threshold it is common for nothing to survive. The original
    # code indexed [0] unconditionally and crashed in that case.
    if not boxes:
        return "No objects detected", image

    draw = ImageDraw.Draw(image)
    confidences = {}
    for (x0, y0, x1, y1), name, score in zip(boxes, names, scores):
        draw.rectangle([x0, y0, x1, y1], outline=BOX_COLOR, width=BOX_WIDTH)
        draw.text((x0 + BOX_WIDTH, y0 + BOX_WIDTH), name, fill=BOX_COLOR)
        # Several boxes may share a class; report its most confident one.
        confidences[name] = max(score, confidences.get(name, 0.0))
    return confidences, image


def _list_examples(directory):
    """Return Gradio-style example rows (``[[path], ...]``) for files in *directory*.

    Hidden files are skipped. A missing directory yields an empty list instead
    of raising.
    """
    if not directory.is_dir():
        return []
    return [
        [str(path)]
        for path in sorted(directory.iterdir())
        if path.is_file() and not path.name.startswith(".")
    ]


# Get example filepaths in a list of lists
examples_path = Path("examples/")
example_list = _list_examples(examples_path)

title = "Object Detection Model\U0001F600\U0001F600"
description = (
    "An object detection model based on a pretrained ResNet neural network. "
    "It can detect car, airplane, bus, boat, train, bird, horse, suitcase, handbag, etc."
)
article = "Created by [Rounak Gera](https://github.com/rounak890)"

demo = gr.Interface(
    fn=predict1,
    inputs=[gr.components.Image(type="pil", label="Image 1")],
    outputs=[
        gr.components.Label(num_top_classes=4, label="Predictions"),
        gr.components.Image(type="pil", label="predicted"),
    ],
    title=title,
    description=description,
    article=article,
    # None (the default) rather than [] when there are no example files.
    examples=example_list or None,
)

if __name__ == "__main__":
    demo.launch()