"""Stage 1: Open-set object detection via VLM. Sends a detection-focused prompt to the VLM and parses structured bounding-box results from the response. """ from __future__ import annotations import json import re from dataclasses import dataclass, field from PIL import Image DETECTION_PROMPT = """\ Detect all objects in this construction site image. Target objects to detect: - Workers / persons - Construction machinery: excavator, crane, tower_crane, mobile_crane, \ dump_truck, cement_truck, wheel_loader, backhoe_loader, bulldozer, \ compactor, grader - Safety equipment worn by workers: hard_hat, safety_vest, harness - Structural elements: scaffolding, ladder, rebar, formwork - Environmental features: open_excavation, debris, exposed_wiring Return ONLY a JSON array. Each element must have: "class": lowercase_with_underscores, "bbox": [x_center, y_center, width, height] (normalized 0-1) Example: [{"class": "worker", "bbox": [0.30, 0.55, 0.08, 0.25]},\ {"class": "excavator", "bbox": [0.65, 0.40, 0.35, 0.30]}] If nothing is detected, return []. """ @dataclass class Detection: class_name: str bbox: list[float] # [cx, cy, w, h] normalized det_id: str = "" confidence: float = 1.0 @property def cx(self) -> float: return self.bbox[0] @property def cy(self) -> float: return self.bbox[1] @property def w(self) -> float: return self.bbox[2] @property def h(self) -> float: return self.bbox[3] @property def xyxy(self) -> list[float]: x1 = self.cx - self.w / 2 y1 = self.cy - self.h / 2 x2 = self.cx + self.w / 2 y2 = self.cy + self.h / 2 return [max(0, x1), max(0, y1), min(1, x2), min(1, y2)] def _extract_json(text: str) -> list[dict]: """Try multiple strategies to extract a JSON array from VLM response.""" # Strategy 1: look for ```json ... ``` block m = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL) if m: try: return json.loads(m.group(1)) except json.JSONDecodeError: pass # Strategy 2: find the outermost [ ... ] m = re.search(r"\[.*\]", text, re.DOTALL) if m: try: return json.loads(m.group(0)) except json.JSONDecodeError: pass # Strategy 3: line-by-line JSON objects results = [] for line in text.splitlines(): line = line.strip().rstrip(",") if line.startswith("{"): try: results.append(json.loads(line)) except json.JSONDecodeError: continue return results def _normalize_class(name: str) -> str: return name.lower().strip().replace(" ", "_").replace("-", "_") WORKER_ALIASES = {"worker", "person", "construction_worker", "laborer", "man", "woman"} MACHINERY_TYPES = { "excavator", "crane", "tower_crane", "mobile_crane", "dump_truck", "cement_truck", "truck", "wheel_loader", "backhoe_loader", "loader", "bulldozer", "dozer", "compactor", "roller", "grader", } def is_worker(class_name: str) -> bool: return _normalize_class(class_name) in WORKER_ALIASES def is_machinery(class_name: str) -> bool: return _normalize_class(class_name) in MACHINERY_TYPES def detect(vlm, image: Image.Image) -> list[Detection]: """Run open-set detection via VLM. Args: vlm: a VLMManager instance with a .generate(image, prompt) method. image: PIL image of the construction site. Returns: List of Detection objects, sorted left-to-right by center x. """ response = vlm.generate(image, DETECTION_PROMPT, max_new_tokens=1024) raw_items = _extract_json(response) detections: list[Detection] = [] for item in raw_items: cls = item.get("class", item.get("label", "")) conf = float(item.get("confidence", 1.0)) if not cls: continue # Qwen3.5 outputs bbox_2d as [x1,y1,x2,y2] in 0-1000 range. # Also handle our requested format: bbox as [cx,cy,w,h] in 0-1. raw = ( item.get("bbox_2d") or item.get("box_2d") or item.get("bbox") or item.get("bounding_box") or [] ) if len(raw) != 4: continue try: coords = [float(v) for v in raw] except (ValueError, TypeError): continue if "bbox_2d" in item or "box_2d" in item or max(coords) > 1: # xyxy format, 0-1000 range → convert to [cx, cy, w, h] 0-1 x1, y1, x2, y2 = [v / 1000.0 for v in coords] cx = (x1 + x2) / 2 cy = (y1 + y2) / 2 w = x2 - x1 h = y2 - y1 bbox = [cx, cy, w, h] else: bbox = coords if all(0 <= v <= 1 for v in bbox): detections.append(Detection( class_name=_normalize_class(cls), bbox=bbox, confidence=conf, )) # Sort by horizontal center (left to right) per the paper detections.sort(key=lambda d: d.cx) return detections