| """Stage 1: Open-set object detection via VLM. |
| |
| Sends a detection-focused prompt to the VLM and parses structured |
| bounding-box results from the response. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import re |
| from dataclasses import dataclass, field |
|
|
| from PIL import Image |
|
|
# Prompt sent to the VLM by detect(). It lists the target classes and asks for
# a bare JSON array of {"class", "bbox"} objects whose boxes are normalized
# [x_center, y_center, width, height] values in 0-1.
# Note: a backslash at the end of a line inside this (non-raw) string is a
# line continuation, so those physical lines are joined into a single line of
# the actual prompt text.
DETECTION_PROMPT = """\
Detect all objects in this construction site image.

Target objects to detect:
- Workers / persons
- Construction machinery: excavator, crane, tower_crane, mobile_crane, \
dump_truck, cement_truck, wheel_loader, backhoe_loader, bulldozer, \
compactor, grader
- Safety equipment worn by workers: hard_hat, safety_vest, harness
- Structural elements: scaffolding, ladder, rebar, formwork
- Environmental features: open_excavation, debris, exposed_wiring

Return ONLY a JSON array. Each element must have:
"class": lowercase_with_underscores,
"bbox": [x_center, y_center, width, height] (normalized 0-1)

Example:
[{"class": "worker", "bbox": [0.30, 0.55, 0.08, 0.25]},\
{"class": "excavator", "bbox": [0.65, 0.40, 0.35, 0.30]}]

If nothing is detected, return [].
"""
|
|
|
|
@dataclass
class Detection:
    """A single object detected in an image.

    Attributes:
        class_name: class label (``detect()`` stores it normalized to
            lowercase_with_underscores).
        bbox: box as ``[x_center, y_center, width, height]``; ``detect()``
            only produces values within 0-1.
        det_id: optional identifier; empty string by default.
        confidence: detection confidence; 1.0 by default.
    """

    class_name: str
    bbox: list[float]
    det_id: str = ""
    confidence: float = 1.0

    @property
    def cx(self) -> float:
        """X coordinate of the box centre (``bbox[0]``)."""
        return self.bbox[0]

    @property
    def cy(self) -> float:
        """Y coordinate of the box centre (``bbox[1]``)."""
        return self.bbox[1]

    @property
    def w(self) -> float:
        """Box width (``bbox[2]``)."""
        return self.bbox[2]

    @property
    def h(self) -> float:
        """Box height (``bbox[3]``)."""
        return self.bbox[3]

    @property
    def xyxy(self) -> list[float]:
        """Corner form ``[x1, y1, x2, y2]``, clipped to the unit square.

        Float clip bounds are used so every element is a float. With int
        bounds, ``max(0, -0.1)`` would return the int ``0``.
        """
        half_w = self.w / 2
        half_h = self.h / 2
        return [
            max(0.0, self.cx - half_w),
            max(0.0, self.cy - half_h),
            min(1.0, self.cx + half_w),
            min(1.0, self.cy + half_h),
        ]
|
|
|
|
| def _extract_json(text: str) -> list[dict]: |
| """Try multiple strategies to extract a JSON array from VLM response.""" |
| |
| m = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) |
| if m: |
| try: |
| return json.loads(m.group(1)) |
| except json.JSONDecodeError: |
| pass |
|
|
| |
| m = re.search(r"\[.*\]", text, re.DOTALL) |
| if m: |
| try: |
| return json.loads(m.group(0)) |
| except json.JSONDecodeError: |
| pass |
|
|
| |
| results = [] |
| for line in text.splitlines(): |
| line = line.strip().rstrip(",") |
| if line.startswith("{"): |
| try: |
| results.append(json.loads(line)) |
| except json.JSONDecodeError: |
| continue |
| return results |
|
|
|
|
| def _normalize_class(name: str) -> str: |
| return name.lower().strip().replace(" ", "_").replace("-", "_") |
|
|
|
|
# Normalized class names (see _normalize_class) that is_worker() treats as a
# person on site.
WORKER_ALIASES = {"worker", "person", "construction_worker", "laborer", "man", "woman"}
# Normalized class names that is_machinery() treats as construction machinery.
# Besides the machinery types requested in DETECTION_PROMPT, this also lists
# names the prompt does not ask for ("truck", "loader", "dozer", "roller"),
# presumably to catch less specific labels the VLM may emit instead.
MACHINERY_TYPES = {
    "excavator", "crane", "tower_crane", "mobile_crane",
    "dump_truck", "cement_truck", "truck",
    "wheel_loader", "backhoe_loader", "loader",
    "bulldozer", "dozer", "compactor", "roller", "grader",
}
|
|
|
|
def is_worker(class_name: str) -> bool:
    """Report whether *class_name* refers to a worker/person.

    The label is passed through ``_normalize_class`` first and then looked
    up in ``WORKER_ALIASES``.
    """
    canonical = _normalize_class(class_name)
    return canonical in WORKER_ALIASES
|
|
|
|
def is_machinery(class_name: str) -> bool:
    """Report whether *class_name* refers to construction machinery.

    The label is passed through ``_normalize_class`` first and then looked
    up in ``MACHINERY_TYPES``.
    """
    canonical = _normalize_class(class_name)
    return canonical in MACHINERY_TYPES
|
|
|
|
# Keys under which a VLM item may carry its box, checked in this order; the
# first key holding a non-empty value is used.
_BBOX_KEYS = ("bbox_2d", "box_2d", "bbox", "bounding_box")
# Boxes read from these keys are always interpreted as 0-1000 scaled
# [x1, y1, x2, y2] corners.
# NOTE(review): some VLM APIs document ``box_2d`` as [ymin, xmin, ymax, xmax],
# and some models emit ``bbox_2d`` in absolute pixels; confirm the format of
# the model behind ``vlm`` before relying on this conversion.
_CORNER_KEYS = frozenset({"bbox_2d", "box_2d"})


def _parse_confidence(value: object) -> float:
    """Coerce a raw ``confidence`` value to float, falling back to 1.0.

    VLM output is untrusted. A ``None`` or non-numeric value must not abort
    the whole detection pass, so it gets the same default as a missing field.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return 1.0


def _parse_bbox(item: dict) -> list[float] | None:
    """Return *item*'s box as ``[cx, cy, w, h]`` in 0-1, or None if unusable."""
    key, raw = next(
        ((k, item[k]) for k in _BBOX_KEYS if item.get(k)),
        (None, None),
    )
    if not isinstance(raw, (list, tuple)) or len(raw) != 4:
        return None
    try:
        coords = [float(v) for v in raw]
    except (TypeError, ValueError):
        return None

    # A corner-style key, or any value above 1, means 0-1000 scaled
    # [x1, y1, x2, y2]. Otherwise the box is taken as already normalized
    # [cx, cy, w, h]. The decision uses the key the box was actually read
    # from, so an empty ``bbox_2d`` next to a normalized ``bbox`` is not
    # misread as corners.
    if key in _CORNER_KEYS or max(coords) > 1:
        x1, y1, x2, y2 = (v / 1000.0 for v in coords)
        bbox = [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
    else:
        bbox = coords

    # Values outside 0-1 fail this check, so the box is dropped. This covers
    # pixel coordinates, inverted corners (negative width/height) and NaN.
    if not all(0 <= v <= 1 for v in bbox):
        return None
    return bbox


def detect(
    vlm,
    image: Image.Image,
    *,
    max_new_tokens: int = 1024,
) -> list[Detection]:
    """Run open-set detection via VLM.

    Each element of the parsed response is validated on its own. Elements
    that are not objects, lack a non-blank string class, or lack a usable
    in-range box are skipped without aborting the call. An unparseable
    ``confidence`` falls back to 1.0.

    Args:
        vlm: a VLMManager instance with a
            ``.generate(image, prompt, max_new_tokens=...)`` method.
        image: PIL image of the construction site.
        max_new_tokens: generation budget passed to ``vlm.generate``
            (default 1024).

    Returns:
        List of Detection objects, sorted left-to-right by center x.
    """
    response = vlm.generate(image, DETECTION_PROMPT, max_new_tokens=max_new_tokens)

    detections: list[Detection] = []
    for item in _extract_json(response):
        if not isinstance(item, dict):
            continue
        cls = item.get("class", item.get("label", ""))
        if not isinstance(cls, str) or not cls.strip():
            continue
        bbox = _parse_bbox(item)
        if bbox is None:
            continue
        detections.append(Detection(
            class_name=_normalize_class(cls),
            bbox=bbox,
            confidence=_parse_confidence(item.get("confidence", 1.0)),
        ))

    detections.sort(key=lambda d: d.cx)
    return detections
|
|