SynLayers commited on
Commit
1240669
·
verified ·
1 Parent(s): 31de91e

Upload demo/infer/vlm_bbox_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo/infer/vlm_bbox_inference.py +235 -0
demo/infer/vlm_bbox_inference.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Shared utility module for VLM bbox-only inference.
4
+
5
+ This module provides:
6
+ - model and processor loading
7
+ - prompts for two coordinate conventions
8
+ - parsing utilities for bbox-only outputs
9
+
10
+ The model output is expected to be either:
11
+ [[x0, y0, x1, y1], ...]
12
+ or:
13
+ [[x_left, y_top, x_right, y_bottom], ...]
14
+
15
+ No caption is generated in this module.
16
+ """
17
+
18
+ import ast
19
+ import re
20
+ from pathlib import Path
21
+
22
+ import torch
23
+ from transformers import AutoProcessor
24
+
25
+
26
+ # Bottom-left origin, y-axis upward.
27
+ BBOX_PROMPT_BOTTOM_LEFT = (
28
+ "<image>This image is 1024 pixels in width and 1024 pixels in height. "
29
+ "The coordinate origin is at the bottom-left corner: x increases to the right, y increases upward. "
30
+ "Detect all objects (layers) in this image. "
31
+ "Output bounding boxes as a list of [x0, y0, x1, y1] where x0=left, y0=bottom, x1=right, y1=top (pixel coordinates). "
32
+ "Output only the list, e.g. [[x0,y0,x1,y1], ...], no other text."
33
+ )
34
+
35
+ # Top-left origin, y-axis downward.
36
+ BBOX_PROMPT_TOP_LEFT = (
37
+ "<image>This image is 1024 pixels in width and 1024 pixels in height. "
38
+ "The coordinate origin is at the top-left corner of the image: x increases to the right, y increases downward. "
39
+ "Detect all objects (layers) in this image. "
40
+ "Output bounding boxes as a list of [x_left, y_top, x_right, y_bottom] in pixel coordinates (top-left origin, y downward). "
41
+ "Output only the list, e.g. [[x_left,y_top,x_right,y_bottom], ...], no other text."
42
+ )
43
+
44
+ # Default prompt used by the generic inference API.
45
+ BBOX_PROMPT = BBOX_PROMPT_TOP_LEFT
46
+ BBOX_SYSTEM = None
47
+
48
+
49
+ def get_model_and_processor(model_path: str, device_map: str = "auto"):
50
+ try:
51
+ from transformers import Qwen3VLForConditionalGeneration
52
+ model_cls = Qwen3VLForConditionalGeneration
53
+ except ImportError:
54
+ try:
55
+ from transformers import Qwen2_5_VLForConditionalGeneration
56
+ model_cls = Qwen2_5_VLForConditionalGeneration
57
+ except ImportError:
58
+ from transformers import AutoModel
59
+ model_cls = AutoModel
60
+
61
+ base_2b = "Qwen/Qwen3-VL-2B-Instruct"
62
+ base_8b = "Qwen/Qwen3-VL-8B-Instruct"
63
+ base_name = base_8b if "8b" in model_path.lower() or "8B" in model_path else base_2b
64
+
65
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
66
+ model_dir = Path(model_path)
67
+
68
+ def load_config(*sources):
69
+ from transformers import AutoConfig
70
+
71
+ for source in sources:
72
+ if not source:
73
+ continue
74
+ try:
75
+ config = AutoConfig.from_pretrained(source, trust_remote_code=True)
76
+ if not hasattr(config, "rope_scaling") or config.rope_scaling is None:
77
+ config.rope_scaling = {}
78
+ return config
79
+ except Exception:
80
+ continue
81
+ return None
82
+
83
+ if model_dir.is_dir() and (model_dir / "adapter_config.json").exists():
84
+ from peft import PeftConfig, PeftModel
85
+
86
+ peft_config = PeftConfig.from_pretrained(str(model_dir))
87
+ base_name = peft_config.base_model_name_or_path or base_name
88
+ config = load_config(base_name, str(model_dir))
89
+
90
+ model = model_cls.from_pretrained(
91
+ base_name,
92
+ config=config,
93
+ torch_dtype=dtype,
94
+ device_map=device_map,
95
+ trust_remote_code=True,
96
+ )
97
+ model = PeftModel.from_pretrained(model, str(model_dir))
98
+
99
+ try:
100
+ processor = AutoProcessor.from_pretrained(str(model_dir), trust_remote_code=True)
101
+ except Exception:
102
+ processor = AutoProcessor.from_pretrained(base_name, trust_remote_code=True)
103
+ else:
104
+ config = load_config(str(model_dir) if model_dir.exists() else None, model_path, base_name)
105
+ model = model_cls.from_pretrained(
106
+ model_path,
107
+ config=config,
108
+ torch_dtype=dtype,
109
+ device_map=device_map,
110
+ trust_remote_code=True,
111
+ )
112
+
113
+ try:
114
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
115
+ except Exception:
116
+ processor = AutoProcessor.from_pretrained(base_name, trust_remote_code=True)
117
+
118
+ return model, processor
119
+
120
+
121
+ def parse_bbox_output(text: str):
122
+ """
123
+ Parse bbox lists from model output.
124
+
125
+ The parser extracts all standalone [a, b, c, d] patterns using regex,
126
+ ignoring outer brackets and extra text. Duplicate boxes are removed to
127
+ reduce the impact of repeated model outputs.
128
+ """
129
+ text = (text or "").strip()
130
+
131
+ # Match standalone boxes such as [102, 611, 511, 1023].
132
+ # Both integers and floating-point numbers are supported.
133
+ pattern = (
134
+ r"\[\s*-?\d+(?:\.\d+)?\s*,"
135
+ r"\s*-?\d+(?:\.\d+)?\s*,"
136
+ r"\s*-?\d+(?:\.\d+)?\s*,"
137
+ r"\s*-?\d+(?:\.\d+)?\s*\]"
138
+ )
139
+ matches = re.findall(pattern, text)
140
+
141
+ parsed_boxes = []
142
+ for match_str in matches:
143
+ try:
144
+ box = ast.literal_eval(match_str)
145
+ if isinstance(box, list) and len(box) == 4:
146
+ parsed_boxes.append(box)
147
+ except (ValueError, SyntaxError):
148
+ continue
149
+
150
+ # Remove duplicate boxes to avoid repeated outputs.
151
+ unique_boxes = []
152
+ seen = set()
153
+
154
+ for b in parsed_boxes:
155
+ key = tuple(float(x) for x in b)
156
+ if key not in seen:
157
+ seen.add(key)
158
+ unique_boxes.append(b)
159
+
160
+ return unique_boxes
161
+
162
+
163
+ def infer_bboxes(image_path: str, model, processor, *, prompt: str, max_new_tokens: int = 512):
164
+ """
165
+ Run bbox inference for a single image.
166
+
167
+ Returns:
168
+ A list of boxes, where each box is [a, b, c, d].
169
+ The coordinate meaning depends on the prompt convention.
170
+ """
171
+ path = Path(image_path)
172
+ if not path.exists():
173
+ return []
174
+
175
+ content = [
176
+ {"type": "image", "image": str(path.absolute())},
177
+ {"type": "text", "text": prompt},
178
+ ]
179
+
180
+ messages = [{"role": "user", "content": content}]
181
+
182
+ inputs = processor.apply_chat_template(
183
+ messages,
184
+ tokenize=True,
185
+ add_generation_prompt=True,
186
+ return_dict=True,
187
+ return_tensors="pt",
188
+ )
189
+
190
+ inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
191
+ inputs.pop("token_type_ids", None)
192
+
193
+ with torch.no_grad():
194
+ generated = model.generate(
195
+ **inputs,
196
+ max_new_tokens=max_new_tokens,
197
+ do_sample=True,
198
+ temperature=0.1,
199
+ repetition_penalty=1.1,
200
+ pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id,
201
+ )
202
+
203
+ input_len = inputs["input_ids"].shape[1]
204
+ output_ids = generated[:, input_len:]
205
+
206
+ output_text = processor.batch_decode(
207
+ output_ids,
208
+ skip_special_tokens=True,
209
+ clean_up_tokenization_spaces=True,
210
+ )
211
+
212
+ raw = (output_text[0] or "").strip()
213
+ boxes = parse_bbox_output(raw)
214
+
215
+ result = []
216
+ for b in boxes:
217
+ if isinstance(b, (list, tuple)) and len(b) >= 4:
218
+ result.append([float(b[0]), float(b[1]), float(b[2]), float(b[3])])
219
+
220
+ return result
221
+
222
+
223
+ def detect_objects(image_path: str, model, processor, *, prompt=None, system=None, max_new_tokens=512):
224
+ """
225
+ Compatibility wrapper using the default bbox prompt.
226
+
227
+ The `system` argument is kept for compatibility with older call sites.
228
+ """
229
+ return infer_bboxes(
230
+ image_path,
231
+ model,
232
+ processor,
233
+ prompt=prompt or BBOX_PROMPT,
234
+ max_new_tokens=max_new_tokens,
235
+ )