SynLayers commited on
Commit
1a01e4b
·
verified ·
1 Parent(s): c8f74f8

Upload demo/infer/run_caption_bbox_infer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo/infer/run_caption_bbox_infer.py +287 -0
demo/infer/run_caption_bbox_infer.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Run whole-caption + bbox inference and save portable JSONL results."""
3
+
4
+ import os
5
+ import json
6
+ import re
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from PIL import Image, ImageDraw
11
+
12
+ try:
13
+ from demo.infer.vlm_bbox_inference import (
14
+ get_model_and_processor,
15
+ parse_bbox_output,
16
+ )
17
+ except ImportError:
18
+ from vlm_bbox_inference import (
19
+ get_model_and_processor,
20
+ parse_bbox_output,
21
+ )
22
+
23
+ PROJECT_ROOT = Path(__file__).resolve().parents[2]
24
+
25
+
26
+ def resolve_default_bbox_model() -> str:
27
+ env_model = os.environ.get("SYNLAYERS_BBOX_MODEL") or os.environ.get("SYNLAYERS_BBOX_MODEL_REPO")
28
+ if env_model:
29
+ return env_model
30
+
31
+ candidates = [
32
+ PROJECT_ROOT if (PROJECT_ROOT / "config.json").exists() and (PROJECT_ROOT / "tokenizer_config.json").exists() else None,
33
+ PROJECT_ROOT / "Bbox-caption-8b",
34
+ Path("/project/llmsvgen/share/data/kmw_layered_checkpoint/Bbox-caption-8b"),
35
+ ]
36
+ for candidate in candidates:
37
+ if candidate and candidate.exists():
38
+ return str(candidate)
39
+ return str(Path("/project/llmsvgen/share/data/kmw_layered_checkpoint/Bbox-caption-8b"))
40
+
41
+
42
+ CAPTION_BBOX_PROMPT_TOP_LEFT = (
43
+ "<image>This image is 1024 pixels in width and 1024 pixels in height. "
44
+ "The coordinate origin is at the top-left corner of the image: x increases to the right, y increases downward. "
45
+ "First describe the whole image in one detailed caption (whole_caption). "
46
+ "Then list the bounding box for each visible layer or object. "
47
+ "Each box is [x_left, y_top, x_right, y_bottom] in pixel coordinates (top-left origin, y downward). "
48
+ "Output a single JSON object with exactly two keys: \"whole_caption\" (string) and \"boxes\" (list of [x_left,y_top,x_right,y_bottom] arrays). "
49
+ "Output only this JSON, no other text or markdown."
50
+ )
51
+
52
+ DEFAULT_BBOX_MODEL = resolve_default_bbox_model()
53
+
54
+ IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
55
+
56
+
57
+ def parse_json_caption_bbox(text: str):
58
+ """Parse model output into `(whole_caption, boxes)`."""
59
+ text = (text or "").strip()
60
+
61
+ if "```" in text:
62
+ parts = text.split("```")
63
+ for p in parts:
64
+ p = p.strip()
65
+ if p.startswith("json"):
66
+ p = p[4:].strip()
67
+ if p.startswith("{"):
68
+ try:
69
+ obj = json.loads(p)
70
+ if isinstance(obj, dict):
71
+ caption = obj.get("whole_caption") or obj.get("caption") or ""
72
+ boxes = obj.get("boxes") or obj.get("bboxes") or []
73
+ if isinstance(boxes, list):
74
+ return caption, boxes
75
+ except json.JSONDecodeError:
76
+ pass
77
+
78
+ match = re.search(r"\{[\s\S]*\}", text)
79
+ if match:
80
+ try:
81
+ obj = json.loads(match.group(0))
82
+ if isinstance(obj, dict):
83
+ caption = obj.get("whole_caption") or obj.get("caption") or ""
84
+ boxes = obj.get("boxes") or obj.get("bboxes") or []
85
+ if isinstance(boxes, list):
86
+ return caption, boxes
87
+ except json.JSONDecodeError:
88
+ pass
89
+
90
+ boxes = parse_bbox_output(text)
91
+ return "", boxes
92
+
93
+
94
+ def format_image_record_path(image_path: Path, data_dir: Path) -> str:
95
+ try:
96
+ return image_path.relative_to(data_dir).as_posix()
97
+ except ValueError:
98
+ return image_path.name
99
+
100
+
101
+ def collect_images(data_dir: Path, max_samples: int | None, target_samples: set | None = None):
102
+ """Collect images and keep a relative path for JSONL output."""
103
+ data_dir = Path(data_dir)
104
+ out = []
105
+
106
+ for d in sorted(data_dir.glob("sample_*")):
107
+ if not d.is_dir():
108
+ continue
109
+ if target_samples is not None and d.name not in target_samples:
110
+ continue
111
+ whole = d / "whole_image.png"
112
+ if whole.exists():
113
+ out.append((d.name, whole, format_image_record_path(whole, data_dir)))
114
+ if max_samples and len(out) >= max_samples:
115
+ return out
116
+
117
+ if not out:
118
+ def _sort_key(p: Path):
119
+ parts = p.stem.rsplit("_", 1)
120
+ try:
121
+ return (parts[0], int(parts[-1]))
122
+ except ValueError:
123
+ return (p.stem, 0)
124
+
125
+ all_imgs = [
126
+ p for ext in IMAGE_EXTS
127
+ for p in data_dir.glob(f"*{ext}")
128
+ if p.is_file()
129
+ ]
130
+
131
+ for p in sorted(all_imgs, key=_sort_key):
132
+ if target_samples is not None and p.stem not in target_samples:
133
+ continue
134
+ out.append((p.stem, p, format_image_record_path(p, data_dir)))
135
+ if max_samples and len(out) >= max_samples:
136
+ return out
137
+
138
+ return out
139
+
140
+
141
+ def draw_boxes(image_path: Path, bboxes: list, out_path: Path, color: str = "lime", width: int = 3):
142
+ """Draw bounding boxes on an image."""
143
+ img = Image.open(image_path).convert("RGB")
144
+ draw = ImageDraw.Draw(img)
145
+
146
+ for b in bboxes:
147
+ if len(b) != 4:
148
+ continue
149
+ x0, y0, x1, y1 = float(b[0]), float(b[1]), float(b[2]), float(b[3])
150
+ x0, x1 = min(x0, x1), max(x0, x1)
151
+ y0, y1 = min(y0, y1), max(y0, y1)
152
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=width)
153
+
154
+ out_path.parent.mkdir(parents=True, exist_ok=True)
155
+ img.save(out_path)
156
+
157
+
158
+ def infer_caption_bbox(image_path: str | Path, model, processor, *, prompt: str, max_new_tokens: int = 1024):
159
+ """Run caption + bbox inference for one image."""
160
+ path = Path(image_path)
161
+ if not path.exists():
162
+ return "", []
163
+
164
+ content = [
165
+ {"type": "image", "image": str(path.absolute())},
166
+ {"type": "text", "text": prompt},
167
+ ]
168
+
169
+ messages = [{"role": "user", "content": content}]
170
+
171
+ inputs = processor.apply_chat_template(
172
+ messages,
173
+ tokenize=True,
174
+ add_generation_prompt=True,
175
+ return_dict=True,
176
+ return_tensors="pt",
177
+ )
178
+
179
+ inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
180
+ inputs.pop("token_type_ids", None)
181
+
182
+ with torch.no_grad():
183
+ generated = model.generate(
184
+ **inputs,
185
+ max_new_tokens=max_new_tokens,
186
+ do_sample=True,
187
+ temperature=0.1,
188
+ repetition_penalty=1.1,
189
+ pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id,
190
+ )
191
+
192
+ input_len = inputs["input_ids"].shape[1]
193
+ output_ids = generated[:, input_len:]
194
+
195
+ output_text = processor.batch_decode(
196
+ output_ids,
197
+ skip_special_tokens=True,
198
+ clean_up_tokenization_spaces=True
199
+ )
200
+
201
+ raw = (output_text[0] or "").strip()
202
+ whole_caption, bboxes = parse_json_caption_bbox(raw)
203
+
204
+ result_boxes = []
205
+ for b in bboxes:
206
+ if isinstance(b, (list, tuple)) and len(b) >= 4:
207
+ result_boxes.append([float(b[0]), float(b[1]), float(b[2]), float(b[3])])
208
+
209
+ return whole_caption, result_boxes
210
+
211
+
212
+ def main():
213
+ import argparse
214
+
215
+ parser = argparse.ArgumentParser(
216
+ description="Caption + bbox inference (top-left origin)"
217
+ )
218
+
219
+ parser.add_argument("--data-dir", type=str, default="testset",
220
+ help="Directory containing sample_* or image files")
221
+ parser.add_argument("--output", type=str, default="outputs/infer/caption_bbox_infer.jsonl",
222
+ help="Output JSONL file")
223
+ parser.add_argument("--model", type=str, default=DEFAULT_BBOX_MODEL,
224
+ help="Model path (merged or LoRA) (default: %(default)s)")
225
+ parser.add_argument("--max-samples", type=int, default=None)
226
+ parser.add_argument("--max-new-tokens", type=int, default=1024)
227
+ parser.add_argument("--samples", type=str, nargs="+",
228
+ help="Specify sample names (e.g. sample_001)")
229
+ parser.add_argument("--vis-dir", type=str, default=None,
230
+ help="Optional directory for visualization")
231
+
232
+ args = parser.parse_args()
233
+
234
+ data_dir = Path(args.data_dir)
235
+ target_samples = set(args.samples) if args.samples else None
236
+
237
+ rows = collect_images(data_dir, args.max_samples, target_samples)
238
+ if not rows:
239
+ print(f"No images found under {data_dir}")
240
+ return
241
+
242
+ print(f"Loading model: {args.model}")
243
+ model, processor = get_model_and_processor(args.model)
244
+
245
+ print(f"Running inference on {len(rows)} samples...")
246
+
247
+ out_path = Path(args.output)
248
+ out_path.parent.mkdir(parents=True, exist_ok=True)
249
+
250
+ vis_dir = Path(args.vis_dir) if args.vis_dir else None
251
+
252
+ with open(out_path, "w", encoding="utf-8") as f:
253
+ for name, image_path, image_record_path in rows:
254
+ print(f" {name}")
255
+
256
+ whole_caption, bboxes = infer_caption_bbox(
257
+ image_path,
258
+ model,
259
+ processor,
260
+ prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
261
+ max_new_tokens=args.max_new_tokens,
262
+ )
263
+
264
+ num_layers = len(bboxes)
265
+
266
+ record = {
267
+ "sample_or_stem": name,
268
+ "image": image_record_path,
269
+ "whole_caption": whole_caption,
270
+ "bboxes": bboxes,
271
+ "num_layers": num_layers,
272
+ "coord": "top_left",
273
+ }
274
+
275
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
276
+ f.flush()
277
+
278
+ if vis_dir:
279
+ draw_boxes(Path(image_path), bboxes, vis_dir / f"{name}_vis.png")
280
+
281
+ print(f"Wrote {out_path}")
282
+ if vis_dir:
283
+ print(f"Visualizations saved to {vis_dir}")
284
+
285
+
286
+ if __name__ == "__main__":
287
+ main()