# Hugging Face page residue (not part of the module):
# uploaded by 2024Lee, "Upload folder using huggingface_hub",
# commit c58b8ac (verified).
"""Stage 1: Open-set object detection via VLM.
Sends a detection-focused prompt to the VLM and parses structured
bounding-box results from the response.
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass, field
from PIL import Image
# Prompt sent to the VLM by detect(). It lists the target object classes and
# asks for a bare JSON array of {"class", "bbox"} records, with boxes given as
# [x_center, y_center, width, height] normalized to 0-1.
# A trailing backslash inside the literal is a line continuation, so the
# wrapped lines reach the model as a single line.
DETECTION_PROMPT = """\
Detect all objects in this construction site image.
Target objects to detect:
- Workers / persons
- Construction machinery: excavator, crane, tower_crane, mobile_crane, \
dump_truck, cement_truck, wheel_loader, backhoe_loader, bulldozer, \
compactor, grader
- Safety equipment worn by workers: hard_hat, safety_vest, harness
- Structural elements: scaffolding, ladder, rebar, formwork
- Environmental features: open_excavation, debris, exposed_wiring
Return ONLY a JSON array. Each element must have:
"class": lowercase_with_underscores,
"bbox": [x_center, y_center, width, height] (normalized 0-1)
Example:
[{"class": "worker", "bbox": [0.30, 0.55, 0.08, 0.25]},\
{"class": "excavator", "bbox": [0.65, 0.40, 0.35, 0.30]}]
If nothing is detected, return [].
"""
@dataclass
class Detection:
    """A single detected object with a center-format bounding box.

    ``bbox`` holds ``[cx, cy, w, h]``; the properties expose each component
    by name and provide a corner-format view clamped to the unit square.
    """

    class_name: str
    bbox: list[float]  # [cx, cy, w, h] normalized
    det_id: str = ""
    confidence: float = 1.0

    @property
    def cx(self) -> float:
        """Horizontal center of the box (first bbox entry)."""
        return self.bbox[0]

    @property
    def cy(self) -> float:
        """Vertical center of the box (second bbox entry)."""
        return self.bbox[1]

    @property
    def w(self) -> float:
        """Box width (third bbox entry)."""
        return self.bbox[2]

    @property
    def h(self) -> float:
        """Box height (fourth bbox entry)."""
        return self.bbox[3]

    @property
    def xyxy(self) -> list[float]:
        """Return the box as ``[x1, y1, x2, y2]`` corners clamped to [0, 1]."""
        half_w = self.w / 2
        half_h = self.h / 2
        left = self.cx - half_w
        top = self.cy - half_h
        right = self.cx + half_w
        bottom = self.cy + half_h
        # Lower corners are floored at 0 and upper corners capped at 1.
        return [max(0, left), max(0, top), min(1, right), min(1, bottom)]
def _extract_json(text: str) -> list[dict]:
"""Try multiple strategies to extract a JSON array from VLM response."""
# Strategy 1: look for ```json ... ``` block
m = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, re.DOTALL)
if m:
try:
return json.loads(m.group(1))
except json.JSONDecodeError:
pass
# Strategy 2: find the outermost [ ... ]
m = re.search(r"\[.*\]", text, re.DOTALL)
if m:
try:
return json.loads(m.group(0))
except json.JSONDecodeError:
pass
# Strategy 3: line-by-line JSON objects
results = []
for line in text.splitlines():
line = line.strip().rstrip(",")
if line.startswith("{"):
try:
results.append(json.loads(line))
except json.JSONDecodeError:
continue
return results
def _normalize_class(name: str) -> str:
return name.lower().strip().replace(" ", "_").replace("-", "_")
# Normalized class names that all denote a person; checked by is_worker().
WORKER_ALIASES = {
    "worker",
    "person",
    "construction_worker",
    "laborer",
    "man",
    "woman",
}

# Normalized class names treated as machinery; checked by is_machinery().
MACHINERY_TYPES = {
    # cranes and diggers
    "excavator",
    "crane",
    "tower_crane",
    "mobile_crane",
    # trucks
    "dump_truck",
    "cement_truck",
    "truck",
    # loaders
    "wheel_loader",
    "backhoe_loader",
    "loader",
    # earthmoving / surfacing
    "bulldozer",
    "dozer",
    "compactor",
    "roller",
    "grader",
}
def is_worker(class_name: str) -> bool:
    """Return True when the normalized label is one of WORKER_ALIASES."""
    normalized = _normalize_class(class_name)
    return normalized in WORKER_ALIASES
def is_machinery(class_name: str) -> bool:
    """Return True when the normalized label is one of MACHINERY_TYPES."""
    normalized = _normalize_class(class_name)
    return normalized in MACHINERY_TYPES
def detect(vlm, image: Image.Image) -> list[Detection]:
    """Run open-set detection via VLM.

    The VLM reply is untrusted free text, so every record is validated and
    malformed records are skipped rather than raising.

    Args:
        vlm: a VLMManager instance with a .generate(image, prompt) method.
        image: PIL image of the construction site.

    Returns:
        List of Detection objects, sorted left-to-right by center x.
        Records without a usable class name or a 4-number box, and boxes
        that fall outside the 0-1 range after conversion, are dropped.
    """
    response = vlm.generate(image, DETECTION_PROMPT, max_new_tokens=1024)
    raw_items = _extract_json(response)

    # Keys tried in priority order; the first with a non-empty value wins.
    bbox_keys = ("bbox_2d", "box_2d", "bbox", "bounding_box")
    # Keys whose values are [x1, y1, x2, y2] in the 0-1000 range.
    xyxy_keys = ("bbox_2d", "box_2d")

    detections: list[Detection] = []
    for item in raw_items:
        if not isinstance(item, dict):
            continue

        # Fall back to "label" when "class" is missing or empty.
        cls = item.get("class") or item.get("label") or ""
        if not isinstance(cls, str) or not cls.strip():
            continue

        # A non-numeric confidence should not discard an otherwise valid box.
        try:
            conf = float(item.get("confidence", 1.0))
        except (TypeError, ValueError):
            conf = 1.0

        # Qwen3.5 outputs bbox_2d as [x1,y1,x2,y2] in 0-1000 range.
        # Also handle our requested format: bbox as [cx,cy,w,h] in 0-1.
        raw, source_key = [], ""
        for key in bbox_keys:
            value = item.get(key)
            if value:
                raw, source_key = value, key
                break
        # Guard against scalars/strings/dicts before calling len().
        if not isinstance(raw, (list, tuple)) or len(raw) != 4:
            continue
        try:
            coords = [float(v) for v in raw]
        except (ValueError, TypeError):
            continue

        # Decide the format from the key that actually supplied the numbers;
        # values above 1 cannot be normalized coordinates either.
        if source_key in xyxy_keys or max(coords) > 1:
            # xyxy format, 0-1000 range → convert to [cx, cy, w, h] 0-1
            x1, y1, x2, y2 = [v / 1000.0 for v in coords]
            bbox = [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]
        else:
            bbox = coords

        # Also rejects NaN and inverted corners (negative width/height).
        if all(0 <= v <= 1 for v in bbox):
            detections.append(Detection(
                class_name=_normalize_class(cls),
                bbox=bbox,
                confidence=conf,
            ))

    # Sort by horizontal center (left to right) per the paper
    detections.sort(key=lambda d: d.cx)
    return detections