| import os
|
| import io
|
| import base64
|
| import gc
|
| import torch
|
| import numpy as np
|
| from fastapi import FastAPI, HTTPException
|
| from fastapi.responses import HTMLResponse
|
| from pydantic import BaseModel
|
| from PIL import Image
|
| from ultralytics import YOLO
|
| import supervision as sv
|
|
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide all GPUs: force CPU-only inference

os.environ["OMP_NUM_THREADS"] = "4"  # cap OpenMP threads used by native libs

torch.set_num_threads(4)  # cap torch intra-op CPU threads


# NOTE: imported *after* the environment tweaks above so that any torch/CUDA
# initialization triggered by the library sees the CPU-only, thread-capped
# settings.
from rfdetr import RFDETRSegPreview
|
|
|
app = FastAPI()


# Checkpoint for the RF-DETR instance-segmentation model.
SEG_MODEL_PATH = "/tmp/checkpoint_best_total.pth"

# Weights for the YOLO image-classification model.
CLS_MODEL_PATH = "weights/yolo12_cls.pt"


# Lazily populated model cache; filled by load_models() on the first request.
models = {"seg": None, "cls": None}
|
|
|
def load_models():
    """Instantiate the segmentation and classification models on first use.

    Populates the module-level ``models`` cache in place; once both entries
    are non-None, subsequent calls are no-ops.
    """
    if models.get("seg") is None:
        models["seg"] = RFDETRSegPreview(pretrain_weights=SEG_MODEL_PATH)
        models["seg"].optimize_for_inference()

    if models.get("cls") is None:
        models["cls"] = YOLO(CLS_MODEL_PATH)
|
|
|
class PredictionConfig(BaseModel):
    """Request body for the /predict endpoint."""

    # Input image: base64 payload, optionally a full "data:...;base64," URL.
    image: str
    # Draw segmentation masks on the annotated output.
    seg_enabled: bool
    # Confidence threshold passed to the segmentation model.
    seg_conf: float
    # Show segmentation confidence in labels (used when classification is off).
    seg_show_conf: bool
    # Run the per-crop classifier on each detection.
    cls_enabled: bool
    # Show classifier confidence in labels.
    cls_show_conf: bool
    # Show the predicted class name in labels.
    cls_show_label: bool
|
|
|
@app.get("/", response_class=HTMLResponse)
async def serve_ui():
    """Serve the single-page UI.

    Returns the contents of ``index.html`` (resolved relative to the working
    directory) as an HTML response. The file is read with an explicit UTF-8
    encoding — the original relied on the platform locale default, which can
    mis-decode UTF-8 content on systems with a non-UTF-8 default (e.g.
    Windows cp1252).
    """
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()
|
|
|
@app.post("/predict")
async def predict(config: PredictionConfig):
    """Run segmentation (and optional per-crop classification) on an image.

    Decodes ``config.image`` (base64, optionally with a data-URL header),
    runs the RF-DETR segmentation model, optionally classifies each detected
    crop with the YOLO model, and returns the annotated image as a PNG data
    URL plus the detection count. Any failure surfaces as HTTP 500 with the
    exception message as the detail.
    """
    load_models()  # lazy first-request model load

    try:

        # Strip a "data:image/...;base64," header if present, keeping only
        # the base64 payload.
        header, encoded = config.image.split(",", 1) if "," in config.image else (None, config.image)
        img_bytes = base64.b64decode(encoded)
        original_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        # Instance segmentation; threshold drops low-confidence detections.
        detections = models["seg"].predict(original_img, threshold=config.seg_conf)

        # Nothing detected: echo the client's image back unannotated.
        if len(detections) == 0:
            return {"annotated": config.image, "count": 0}

        labels = []
        if config.cls_enabled:
            # Classify each detection individually: crop its bounding box
            # from the original image and run the YOLO classifier on it.
            for i in range(len(detections.xyxy)):
                x1, y1, x2, y2 = detections.xyxy[i].astype(int)
                crop = original_img.crop((x1, y1, x2, y2))
                cls_res = models["cls"](crop)[0]

                top1_idx = cls_res.probs.top1
                name = cls_res.names[top1_idx]
                conf = float(cls_res.probs.top1conf)

                # Compose the label from the enabled parts; may end up empty
                # when both display toggles are off.
                label_str = ""
                if config.cls_show_label: label_str += f"{name} "
                if config.cls_show_conf: label_str += f"{conf:.2f}"
                labels.append(label_str.strip())
        else:
            # Classifier disabled: generic "Leaf" labels, optionally with
            # the segmentation confidence appended.
            for conf in detections.confidence:
                labels.append(f"Leaf {conf:.2f}" if config.seg_show_conf else "Leaf")

        # Mask and label annotators share one palette so colors line up
        # per detection.
        palette = sv.ColorPalette.from_hex(["#EA782D", "#FF7A5A", "#FFA382"])
        mask_annotator = sv.MaskAnnotator(color=palette)
        label_annotator = sv.LabelAnnotator(
            color=palette,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=0.5
        )

        # Draw on a copy of the input; masks only when segmentation display
        # is enabled, labels always.
        annotated_img = original_img.copy()
        if config.seg_enabled:
            annotated_img = mask_annotator.annotate(scene=annotated_img, detections=detections)

        annotated_img = label_annotator.annotate(scene=annotated_img, detections=detections, labels=labels)

        # Re-encode the annotated image as a PNG data URL for the client.
        buffered = io.BytesIO()
        annotated_img.save(buffered, format="PNG")
        encoded_res = base64.b64encode(buffered.getvalue()).decode("ascii")

        return {
            "annotated": f"data:image/png;base64,{encoded_res}",
            "count": len(detections)
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Reclaim memory from large image buffers between requests.
        gc.collect()
|
|
|
if __name__ == "__main__":
    # Local/dev entry point: run the ASGI app directly with uvicorn on the
    # standard HF Spaces port.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)