| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.responses import JSONResponse |
| from pydantic import BaseModel |
| import base64 |
| import io |
| import os |
| from PIL import Image |
| import torch |
| import numpy as np |
| from ultralytics import YOLO |
| from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
| |
# Make sure the directory that holds the icon-detection weights exists.
# exist_ok=True makes this a single idempotent call, so there is no window
# between checking for the directory and creating it.
os.makedirs("weights/icon_detect", exist_ok=True)
|
|
| |
# Load the icon-detection weights once, then move the model to the GPU only
# when CUDA is actually present. A machine without a GPU therefore no longer
# goes through an exception and a second load of the same weights file.
yolo_model = YOLO("weights/icon_detect/best.pt")
if torch.cuda.is_available():
    try:
        # Called for its effect on the model; the binding above is kept so the
        # result does not depend on what .to() returns.
        yolo_model.to("cuda")
    except Exception as e:
        # Best-effort fallback: report the problem and start again from a
        # freshly loaded model that was never moved.
        print(f"Error loading YOLO model: {e}")
        yolo_model = YOLO("weights/icon_detect/best.pt")
|
|
| |
def _load_caption_model():
    """Instantiate the caption model in float16 without moving it to a device."""
    return AutoModelForCausalLM.from_pretrained(
        "microsoft/OmniParser",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )


# Load the processor and the caption model once, then move the model to the
# GPU only when CUDA is present, instead of repeating the whole loading
# sequence inside an exception handler.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
model = _load_caption_model()
if torch.cuda.is_available():
    try:
        model = model.to("cuda")
    except Exception as e:
        # Best-effort fallback: report the problem and use a freshly loaded
        # model that was never moved.
        print(f"Error loading caption model: {e}")
        model = _load_caption_model()
# NOTE(review): the model keeps torch.float16 even when it stays on the CPU,
# exactly as before - confirm that CPU inference works with that dtype.

# Bundle handed to get_som_labeled_img as caption_model_processor.
caption_model_processor = {"processor": processor, "model": model}
print("Finished loading models!")
|
|
| |
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()
|
|
| |
# Response schema for the /process_image endpoint. Every field is a plain
# string; structured values are stringified before they are stored here.
class ProcessResponse(BaseModel):
    # Base64 text of the annotated image.
    image: str

    # Parsed content entries, rendered into a single string.
    parsed_content_list: str

    # String rendering of the label coordinates.
    label_coordinates: str
|
|
| |
def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR plus icon detection/captioning on an image and package the result.

    The image is written to ``imgs/saved_image_demo.png`` because both helper
    calls below are given a file path rather than an in-memory image.

    Args:
        image_input: Image to analyse.
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` holding the annotated image as base64-encoded
        PNG text, the parsed content entries joined with newlines, and the
        ``str()`` rendering of the label coordinates.

    Raises:
        HTTPException: With status 500 when the OCR call or the
            detection/caption call raises.
    """
    # NOTE(review): check_ocr_box and get_som_labeled_img are neither defined
    # nor imported in the visible part of this file - confirm they are
    # brought into scope, otherwise every call ends in the 500 handlers below.
    # NOTE(review): the save path is fixed, so calls running at the same time
    # would overwrite each other's file - confirm calls are serialized.
    image_save_path = "imgs/saved_image_demo.png"
    # The target directory is not created anywhere else in this module.
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    # Scale the overlay drawing parameters with the image width (3200 px
    # gives a ratio of 1). The width is read from the in-memory image, so the
    # saved file does not have to be reopened and no file handle is left open.
    box_overlay_ratio = image_input.size[0] / 3200
    draw_bbox_config = {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }

    # Step 1: OCR over the saved image.
    try:
        ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"OCR processing failed: {e}"
        ) from e

    # Step 2: detection and captioning, seeded with the OCR text and boxes.
    try:
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,  # keyword spelling is the callee's
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"YOLO or caption model inference failed: {e}",
        ) from e

    # The labelled image arrives base64-encoded; decode it and re-encode it
    # as PNG so the response always carries PNG data.
    labeled_image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
    buffered = io.BytesIO()
    labeled_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # str() each entry so non-string items cannot make the join raise;
    # entries that already are strings are unaffected.
    parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
|
|
| |
# POST /process_image: read the uploaded file, decode it as an RGB image and
# hand it to process() together with the two thresholds. Anything that goes
# wrong while reading or decoding the upload is reported as HTTP 400.
# process() is called directly (not awaited) from this coroutine.
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    try:
        raw_bytes = await image_file.read()
        byte_stream = io.BytesIO(raw_bytes)
        rgb_image = Image.open(byte_stream).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}")

    return process(rgb_image, box_threshold, iou_threshold)
|
|