import torch
import numpy as np
from PIL import Image
import json, os
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
from matplotlib import cm

from models import get_model
from utils import resize_density_map, init_seeds

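# ImageNet normalization statistics, the blend weight used when overlaying the
# colorized density map on the input image, and a fixed random seed.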
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
init_seeds(42)

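# Model configuration: count-bin settings (truncation, reduction, granularity,
# anchor points), the CLIP ViT-L/14 backbone, its input resolution, and prompt
# settings (prompt type plus visual prompt tuning).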
truncation = 4
reduction = 8
granularity = "fine"
anchor_points = "average"

model_name = "clip_vit_l_14"
input_size = 224

prompt_type = "word"
num_vpt = 32
vpt_drop = 0.
deep_vpt = True

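# Run on GPU when available, otherwise fall back to CPU.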
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

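# Load the count bins and anchor points for the chosen reduction/truncation
# from the NWPU entry of the config file.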
if truncation is None:
    bins, anchor_points = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    if anchor_points == "average":
        anchor_points = config["anchor_points"][granularity]["average"]
    else:
        anchor_points = config["anchor_points"][granularity]["middle"]
    bins = [(float(b[0]), float(b[1])) for b in bins]
    anchor_points = [float(p) for p in anchor_points]

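# Build the CLIP-EBC model and load the pre-trained NWPU-Crowd weights.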
model = get_model(
    backbone=model_name,
    input_size=input_size,
    reduction=reduction,
    bins=bins,
    anchor_points=anchor_points,
    prompt_type=prompt_type,
    num_vpt=num_vpt,
    vpt_drop=vpt_drop,
    deep_vpt=deep_vpt,
)

weights_path = os.path.join("pre-trained weights", "CLIP-EBC-ViT-L-14-NWPU", "model.safetensors")
state_dict = load_file(weights_path)
# The checkpoint keys carry a "model." prefix; strip it so they match this model's state dict.
new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)
model.to(device)
model.eval()

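# Preprocessing: convert a PIL image into a normalized batch tensor, upscaling
# images that are smaller than the model input size.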
def transform(image: Image.Image):
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)

    # If either side is smaller than the model input size, upscale so that both
    # sides are at least `input_size`.
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)

    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)

def predict(image: Image.Image):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)

    with torch.no_grad():
        density_map = model(input_tensor)
    total_count = density_map.sum().item()
    # Resize the density map back to the original image resolution for visualization.
    resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()

    # Normalize to [0, 1] so the colormap spans the full range; eps guards
    # against division by zero on a constant map.
    eps = 1e-8
    density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)

    # Colorize with the "jet" colormap and convert to an RGBA image.
    colormap = cm.get_cmap("jet")
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")

    # Blend the colorized density map over the original image.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)

    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"

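# Gradio interface: upload or paste an image on the left, see the predicted
# density-map overlay and total count on the right, or click an example below.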
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting Demo")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                label="Input Image",
                sources=["upload", "clipboard"],
                type="pil",
            )
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")

    submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])

    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
        ],
        inputs=input_img,
        label="Try an example",
    )

# `share=True` additionally serves the demo through a temporary public Gradio link.
demo.launch(share=True)