| import os |
| |
| |
| import matplotlib.pyplot as plt |
| import gradio as gr |
| import cv2 |
| import numpy as np |
| import torch |
| from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry |
| from PIL import ImageDraw,Image |
| from utils.tools import box_prompt, format_results, point_prompt |
| from utils.tools_gradio import fast_process |
|
|
# Run the model on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the MobileSAM weights. The MOBILE_SAM_CHECKPOINT environment
# variable overrides the default so the script is not tied to one machine's
# directory layout; without it the original path is used unchanged.
sam_checkpoint = os.environ.get(
    "MOBILE_SAM_CHECKPOINT",
    r"F:\zht\code\MobileSAM-master\weights\mobile_sam.pt",
)
model_type = "vit_t"
mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()

# Both helpers wrap the same loaded model instance.
mask_generator = SamAutomaticMaskGenerator(mobile_sam)
predictor = SamPredictor(mobile_sam)

# Prompt state read (and reset) by segment_with_boxs through ``global``
# statements. Defining the names here means the function can be called after
# a plain import instead of failing with NameError when the script entry
# point has not populated them.
global_points = []
global_point_label = []
|
|
| |
|
|
@torch.no_grad()
def segment_with_boxs(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment ``image`` with the prompts stored in the module globals.

    Prompts are read from ``global_points`` (one row per prompt) and
    ``global_point_label``; both lists are reset to empty before returning.
    Rows with two values are handled as ``(x, y)`` point prompts. Rows with
    four values are handled as ``(x1, y1, x2, y2)`` box prompts and are sent
    to the predictor through its ``box`` argument, because ``point_coords``
    only accepts two-column rows.

    Args:
        image: Image object exposing ``size`` and ``resize`` (the script
            entry point passes a PIL image).
        input_size: Length, in pixels, that the longer image side is resized
            to before prediction. Converted with ``int``.
        better_quality: Forwarded unchanged to ``fast_process``.
        withContours: Forwarded unchanged to ``fast_process``.
        use_retina: Forwarded unchanged to ``fast_process``.
        mask_random_color: Forwarded unchanged to ``fast_process``.

    Returns:
        A ``(fig, image)`` tuple where ``fig`` is the value returned by
        ``fast_process`` and ``image`` is the resized input. When no prompts
        are stored the model is not run and ``(image, image)`` is returned.

    Raises:
        ValueError: If the prompt rows do not all have 2 or all have 4 values.
    """
    global global_points
    global global_point_label

    # Resize so that the longer side equals ``input_size``.
    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    # Bring the stored prompt coordinates into the resized image's pixel grid.
    scaled_points = np.array(
        [[int(x * scale) for x in point] for point in global_points]
    )
    print("nnnnnnnnnnnnnnnnnnnnnnnnnnnnn00nnnnn", scaled_points)
    scaled_point_label = np.array(global_point_label)

    # Nothing to segment: skip the model instead of failing inside it.
    if scaled_points.size == 0:
        print("No prompts selected")
        global_points = []
        global_point_label = []
        return image, image

    # Validate before touching the predictor so a bad prompt fails clearly.
    if scaled_points.ndim != 2 or scaled_points.shape[1] not in (2, 4):
        raise ValueError(
            "each prompt must be [x, y] or [x1, y1, x2, y2]; "
            f"got prompts with shape {scaled_points.shape}"
        )

    nd_image = np.array(image)
    print("mmmmmmm0mmmm", nd_image.shape)
    predictor.set_image(nd_image)

    if scaled_points.shape[1] == 4:
        # Box prompts: ``predict`` accepts a single XYXY box per call, so each
        # box is predicted separately and its highest-scoring mask is kept.
        best_masks = []
        for box in scaled_points:
            masks, scores, _ = predictor.predict(
                box=box,
                multimask_output=True,
            )
            best_masks.append(masks[int(np.argmax(scores))])
        # Merge into one mask so the array handed on has the same
        # single-mask layout as the point-prompt path below.
        annotations = np.array([np.logical_or.reduce(best_masks)])
    else:
        masks, scores, logits = predictor.predict(
            point_coords=scaled_points,
            point_labels=scaled_point_label,
            multimask_output=True,
        )
        results = format_results(masks, scores, logits, 0)
        print("mmmmmmmmmmmmmmmm2222m", len(results))
        annotations, _ = point_prompt(
            results, scaled_points, scaled_point_label, new_h, new_w
        )
        annotations = np.array([annotations])

    # Debug visualisation of the selected mask: written to a fixed path and
    # shown in a window. NOTE(review): plt.show() may block until the window
    # is closed, depending on the matplotlib backend.
    plt.imshow(annotations[0], cmap='viridis')
    plt.colorbar()
    plt.savefig(r'F:\zht\code\2.png')
    plt.show()

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        # NOTE(review): integer division yields 0 when input_size > 1024;
        # confirm how fast_process interprets ``scale``.
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    # Prompts are single-use: clear them for the next call.
    global_points = []
    global_point_label = []
    return fig, image
|
|
| |
if __name__ == "__main__":
    # Prompt state consumed by segment_with_boxs: one four-value row with
    # label 1.
    global_points = [[324, 740, 1448, 1192]]
    global_point_label = [1]

    # Sample image to segment; its (width, height) is printed for reference.
    path = r"F:\zht\code\MobileSAM-master\app\assets\05.jpg"
    image1 = Image.open(path)
    print(image1.size)

    # Rendering options, spelled out explicitly even though they match the
    # function's defaults.
    options = {
        "input_size": 1024,
        "better_quality": False,
        "withContours": True,
        "use_retina": True,
        "mask_random_color": True,
    }
    segment_with_boxs(image1, **options)