| import numpy as np |
| import cv2 |
| from ultralytics import FastSAM |
| import torch |
| import gc |
|
|
| |
# Mapping from model-size key to the local FastSAM checkpoint path.
# Keys here define the valid values for the `model_size` arguments below.
MODELS = {
    "small": "./models/FastSAM-s.pt",
    "large": "./models/FastSAM-x.pt"
}
|
|
def clear_gpu_memory():
    """Release cached GPU memory held by PyTorch.

    Runs Python garbage collection first so unreachable tensors are
    finalized, then frees the CUDA caching allocator and shared IPC
    handles when a CUDA device is available. No-op on CPU-only hosts.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
|
|
def get_model(model_size: str = "large"):
    """Load a FastSAM model of the requested size.

    Args:
        model_size: One of the keys in ``MODELS`` ("small" or "large").

    Returns:
        A loaded ``FastSAM`` model instance.

    Raises:
        ValueError: If ``model_size`` is not a known key.
        RuntimeError: If the checkpoint fails to load (original error chained).
    """
    if model_size not in MODELS:
        raise ValueError(f"Invalid model size. Available sizes: {list(MODELS.keys())}")

    try:
        return FastSAM(MODELS[model_size])
    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
|
|
def mask_to_points(mask: np.ndarray) -> list:
    """Convert a binary mask to the contour points of its largest region.

    Only the external contour with the greatest area is kept; inner holes
    and smaller disconnected regions are discarded.

    Args:
        mask: 2D numpy array representing the mask; non-zero values are
            treated as foreground (assumed to be 0/1 valued — values > 1
            would truncate under the uint8 cast; TODO confirm with callers).

    Returns:
        list: Flattened list of contour coordinates [x1, y1, x2, y2, ...],
        or an empty list when the mask contains no foreground.
    """
    mask_uint8 = mask.astype(np.uint8) * 255

    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if not contours:
        return []

    # Keep only the largest connected region by contour area.
    contour = max(contours, key=cv2.contourArea)

    # OpenCV contours have shape (N, 1, 2); flatten to [x1, y1, x2, y2, ...]
    # in one vectorized pass instead of a Python-level loop.
    return contour.reshape(-1).astype(float).tolist()
|
|
def segment_image_with_prompt(
    image: np.ndarray,
    model_size: str = "large",
    conf: float = 0.4,
    iou: float = 0.9,
    bboxes: list = None,
    points: list = None,
    labels: list = None,
    texts: str = None
):
    """Segment an image with optional prompts (boxes, points, text).

    Args:
        image: Input image as a numpy array.
        model_size: Model size, "small" or "large".
        conf: Confidence threshold.
        iou: IoU threshold.
        bboxes: Flat list of box coordinates [x1, y1, x2, y2, ...].
        points: Point prompts [[x1, y1], [x2, y2], ...]; only used when
            ``labels`` is also provided.
        labels: Per-point labels [0, 1, ...] matching ``points``.
        texts: Text prompt.

    Returns:
        dict: ``{"total_segments": int, "segments": list}`` where each
        segment is a flattened contour [x1, y1, x2, y2, ...].

    Raises:
        RuntimeError: Wraps any failure (including invalid input) with the
            original exception chained as the cause.
    """
    try:
        if image is None:
            raise ValueError("Invalid image input")

        model = get_model(model_size)

        model_args = {
            "device": "cpu",
            "retina_masks": True,
            "conf": conf,
            "iou": iou
        }

        # Prompts are optional; point prompts require matching labels.
        if bboxes is not None:
            model_args["bboxes"] = bboxes
        if points is not None and labels is not None:
            model_args["points"] = points
            model_args["labels"] = labels
        if texts is not None:
            model_args["texts"] = texts

        everything_results = model(image, **model_args)

        segments = []
        result = None
        if everything_results and len(everything_results) > 0:
            result = everything_results[0]
            if hasattr(result, 'masks') and result.masks is not None:
                masks = result.masks.data.cpu().numpy()

                # Use a distinct name so the `points` parameter is not shadowed.
                for mask in masks:
                    contour_points = mask_to_points(mask)
                    if contour_points:
                        segments.append(contour_points)

        # Drop large objects eagerly to reduce peak memory before returning.
        del model
        del everything_results
        # BUG FIX: `result` was previously referenced unconditionally here,
        # raising NameError when inference returned no results.
        if result is not None:
            if hasattr(result, 'masks'):
                del result.masks
            del result

        return {
            "total_segments": len(segments),
            "segments": segments
        }
    except Exception as e:
        # Best-effort cleanup, then surface the failure with its cause chained.
        clear_gpu_memory()
        raise RuntimeError(f"Error processing image: {str(e)}") from e
|
|