SynLayers committed on
Commit
31de91e
·
verified ·
1 Parent(s): e68d78e

Upload demo/infer/run_caption_bbox_infer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo/infer/run_caption_bbox_infer.py +285 -0
demo/infer/run_caption_bbox_infer.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Run whole-caption + bbox inference and save portable JSONL results."""
3
+
4
+ import os
5
+ import json
6
+ import re
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from PIL import Image, ImageDraw
11
+
12
+ try:
13
+ from demo.infer.vlm_bbox_inference import (
14
+ get_model_and_processor,
15
+ parse_bbox_output,
16
+ )
17
+ except ImportError:
18
+ from vlm_bbox_inference import (
19
+ get_model_and_processor,
20
+ parse_bbox_output,
21
+ )
22
+
23
# Two directory levels above the folder that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[2]


def resolve_default_bbox_model() -> str:
    """Pick the default location of the caption/bbox model.

    Candidates are probed in this order and the first one that exists on
    disk is returned:

    1. the path given in the ``SYNLAYERS_BBOX_MODEL`` environment variable
       (only when it is set and non-empty);
    2. ``PROJECT_ROOT`` itself, when it holds both ``config.json`` and
       ``tokenizer_config.json``;
    3. ``PROJECT_ROOT / "Bbox-caption-8b"``;
    4. the shared checkpoint directory.

    Returns:
        The chosen path as a string. When no candidate exists, the shared
        checkpoint path is returned anyway.
    """
    shared_checkpoint = Path("/project/llmsvgen/share/data/kmw_layered_checkpoint/Bbox-caption-8b")
    search_order = []

    override = os.environ.get("SYNLAYERS_BBOX_MODEL")
    if override:
        search_order.append(Path(override))

    root_holds_model = all(
        (PROJECT_ROOT / marker).exists()
        for marker in ("config.json", "tokenizer_config.json")
    )
    if root_holds_model:
        search_order.append(PROJECT_ROOT)

    search_order.append(PROJECT_ROOT / "Bbox-caption-8b")
    search_order.append(shared_checkpoint)

    chosen = next((path for path in search_order if path.exists()), shared_checkpoint)
    return str(chosen)
38
+
39
+
40
# Instruction sent alongside every image: asks for one whole-image caption plus
# a list of boxes, all in a top-left-origin pixel frame, returned as a single
# JSON object. The sentences are joined with single spaces.
CAPTION_BBOX_PROMPT_TOP_LEFT = " ".join(
    (
        "<image>This image is 1024 pixels in width and 1024 pixels in height.",
        "The coordinate origin is at the top-left corner of the image: x increases to the right, y increases downward.",
        "First describe the whole image in one detailed caption (whole_caption).",
        "Then list the bounding box for each visible layer or object.",
        "Each box is [x_left, y_top, x_right, y_bottom] in pixel coordinates (top-left origin, y downward).",
        'Output a single JSON object with exactly two keys: "whole_caption" (string) and "boxes" (list of [x_left,y_top,x_right,y_bottom] arrays).',
        "Output only this JSON, no other text or markdown.",
    )
)
49
+
50
# Resolved once at import time; used as the default for the --model CLI option.
DEFAULT_BBOX_MODEL = resolve_default_bbox_model()

# File suffixes (lower-case, with leading dot) treated as images when scanning
# a flat directory.
IMAGE_EXTS = {
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
    ".webp",
}
53
+
54
+
55
def _caption_and_boxes_from_obj(obj):
    """Return ``(caption, boxes)`` from a decoded JSON value, or ``None`` if unusable.

    The value is usable when it is a dict whose ``boxes`` (or ``bboxes``)
    entry is a list or is missing/empty. The caption is read from
    ``whole_caption`` (or ``caption``) and coerced to ``str``.
    """
    if not isinstance(obj, dict):
        return None
    caption = obj.get("whole_caption") or obj.get("caption") or ""
    boxes = obj.get("boxes") or obj.get("bboxes") or []
    if not isinstance(boxes, list):
        return None
    if not isinstance(caption, str):
        # Keep the caption field a string even if the model emitted a number or list.
        caption = str(caption)
    return caption, boxes


def parse_json_caption_bbox(text: str):
    """Parse model output into `(whole_caption, boxes)`.

    Parsing is attempted in three stages:

    1. If the text contains Markdown code fences, each fenced part that starts
       with ``{`` (after an optional ``json`` tag) is decoded as JSON.
    2. Otherwise, or if that fails, a JSON object is decoded starting at the
       first ``{`` in the text. Anything after the end of that object is
       ignored, so trailing commentary (even commentary containing ``}``)
       does not break parsing.
    3. As a last resort the raw text is handed to ``parse_bbox_output`` and
       an empty caption is returned.

    Args:
        text: Raw decoded model output; ``None`` or empty is tolerated.

    Returns:
        A ``(caption, boxes)`` tuple. ``boxes`` is the list found in the JSON
        object (elements are not validated here) or whatever
        ``parse_bbox_output`` returns in the fallback case.
    """
    text = (text or "").strip()
    decoder = json.JSONDecoder()

    if "```" in text:
        for part in text.split("```"):
            part = part.strip()
            if part.startswith("json"):
                part = part[4:].strip()
            if not part.startswith("{"):
                continue
            try:
                obj, _ = decoder.raw_decode(part)
            except json.JSONDecodeError:
                continue
            parsed = _caption_and_boxes_from_obj(obj)
            if parsed is not None:
                return parsed

    start = text.find("{")
    if start != -1:
        try:
            # raw_decode stops at the end of the first complete JSON value,
            # unlike a greedy "{...}" regex that runs to the last "}" in the text.
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        parsed = _caption_and_boxes_from_obj(obj)
        if parsed is not None:
            return parsed

    boxes = parse_bbox_output(text)
    return "", boxes
90
+
91
+
92
def format_image_record_path(image_path: Path, data_dir: Path) -> str:
    """Return the image location to store in an output record.

    Args:
        image_path: Path of the image file.
        data_dir: Directory the path should be expressed relative to.

    Returns:
        The POSIX-style path of ``image_path`` relative to ``data_dir``, or
        just the file name when ``image_path`` does not lie under ``data_dir``.
    """
    try:
        relative = image_path.relative_to(data_dir)
    except ValueError:
        # Not located under data_dir: fall back to the bare file name.
        return image_path.name
    return relative.as_posix()
97
+
98
+
99
def collect_images(data_dir: Path, max_samples: int | None, target_samples: set | None = None):
    """Gather the images to run inference on.

    Two directory layouts are tried, in order:

    1. ``sample_*`` sub-directories of ``data_dir`` that contain a
       ``whole_image.png``; the entry name is the directory name.
    2. Only when layout 1 produced nothing: files directly inside
       ``data_dir`` whose suffix is in ``IMAGE_EXTS``, ordered by
       ``(prefix, trailing integer)`` of the stem; the entry name is the stem.

    Args:
        data_dir: Directory to scan.
        max_samples: Stop after this many entries; ``None`` or ``0`` means no limit.
        target_samples: When given, only entries whose name is in this set are kept.

    Returns:
        A list of ``(name, image_path, record_path)`` tuples, where
        ``record_path`` comes from ``format_image_record_path``.
    """
    root = Path(data_dir)
    collected = []

    def _wanted(name):
        return target_samples is None or name in target_samples

    def _limit_reached():
        return bool(max_samples) and len(collected) >= max_samples

    # Layout 1: one directory per sample.
    for sample_dir in sorted(root.glob("sample_*")):
        if not sample_dir.is_dir() or not _wanted(sample_dir.name):
            continue
        whole_image = sample_dir / "whole_image.png"
        if not whole_image.exists():
            continue
        collected.append((sample_dir.name, whole_image, format_image_record_path(whole_image, root)))
        if _limit_reached():
            return collected

    if collected:
        return collected

    # Layout 2: flat directory of image files.
    def _stem_order(path: Path):
        pieces = path.stem.rsplit("_", 1)
        try:
            return (pieces[0], int(pieces[-1]))
        except ValueError:
            # No numeric tail: order by the whole stem.
            return (path.stem, 0)

    flat_images = []
    for ext in IMAGE_EXTS:
        flat_images.extend(p for p in root.glob(f"*{ext}") if p.is_file())

    for image in sorted(flat_images, key=_stem_order):
        if not _wanted(image.stem):
            continue
        collected.append((image.stem, image, format_image_record_path(image, root)))
        if _limit_reached():
            return collected

    return collected
137
+
138
+
139
def draw_boxes(image_path: Path, bboxes: list, out_path: Path, color: str = "lime", width: int = 3):
    """Draw bounding boxes on an image and save the result.

    Args:
        image_path: Image to annotate.
        bboxes: Iterable of 4-element boxes ``[x0, y0, x1, y1]``; entries whose
            length is not 4 are skipped. Corner order is normalised, so
            swapped coordinates are tolerated.
        out_path: Destination file (``str`` or ``Path``); parent directories
            are created if needed.
        color: Outline colour passed to ``ImageDraw.rectangle``.
        width: Outline width passed to ``ImageDraw.rectangle``.
    """
    out_path = Path(out_path)

    # Open inside a context manager so the source file handle is released as
    # soon as the pixels have been copied into the RGB working image.
    with Image.open(image_path) as source:
        canvas = source.convert("RGB")
    painter = ImageDraw.Draw(canvas)

    for box in bboxes:
        if len(box) != 4:
            continue
        x0, y0, x1, y1 = (float(value) for value in box)
        left, right = min(x0, x1), max(x0, x1)
        top, bottom = min(y0, y1), max(y0, y1)
        painter.rectangle([left, top, right, bottom], outline=color, width=width)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    canvas.save(out_path)
154
+
155
+
156
+ def infer_caption_bbox(image_path: str | Path, model, processor, *, prompt: str, max_new_tokens: int = 1024):
157
+ """Run caption + bbox inference for one image."""
158
+ path = Path(image_path)
159
+ if not path.exists():
160
+ return "", []
161
+
162
+ content = [
163
+ {"type": "image", "image": str(path.absolute())},
164
+ {"type": "text", "text": prompt},
165
+ ]
166
+
167
+ messages = [{"role": "user", "content": content}]
168
+
169
+ inputs = processor.apply_chat_template(
170
+ messages,
171
+ tokenize=True,
172
+ add_generation_prompt=True,
173
+ return_dict=True,
174
+ return_tensors="pt",
175
+ )
176
+
177
+ inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
178
+ inputs.pop("token_type_ids", None)
179
+
180
+ with torch.no_grad():
181
+ generated = model.generate(
182
+ **inputs,
183
+ max_new_tokens=max_new_tokens,
184
+ do_sample=True,
185
+ temperature=0.1,
186
+ repetition_penalty=1.1,
187
+ pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id,
188
+ )
189
+
190
+ input_len = inputs["input_ids"].shape[1]
191
+ output_ids = generated[:, input_len:]
192
+
193
+ output_text = processor.batch_decode(
194
+ output_ids,
195
+ skip_special_tokens=True,
196
+ clean_up_tokenization_spaces=True
197
+ )
198
+
199
+ raw = (output_text[0] or "").strip()
200
+ whole_caption, bboxes = parse_json_caption_bbox(raw)
201
+
202
+ result_boxes = []
203
+ for b in bboxes:
204
+ if isinstance(b, (list, tuple)) and len(b) >= 4:
205
+ result_boxes.append([float(b[0]), float(b[1]), float(b[2]), float(b[3])])
206
+
207
+ return whole_caption, result_boxes
208
+
209
+
210
def main():
    """Command-line entry point.

    Collects images from ``--data-dir``, loads the model given by ``--model``,
    runs caption + bbox inference on each image and writes one JSON record
    per line to ``--output``. When ``--vis-dir`` is given, an annotated copy
    of every image is saved there as ``<name>_vis.png``.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Caption + bbox inference (top-left origin)")
    cli.add_argument(
        "--data-dir", type=str, default="testset",
        help="Directory containing sample_* or image files",
    )
    cli.add_argument(
        "--output", type=str, default="outputs/infer/caption_bbox_infer.jsonl",
        help="Output JSONL file",
    )
    cli.add_argument(
        "--model", type=str, default=DEFAULT_BBOX_MODEL,
        help="Model path (merged or LoRA) (default: %(default)s)",
    )
    cli.add_argument("--max-samples", type=int, default=None)
    cli.add_argument("--max-new-tokens", type=int, default=1024)
    cli.add_argument(
        "--samples", type=str, nargs="+",
        help="Specify sample names (e.g. sample_001)",
    )
    cli.add_argument(
        "--vis-dir", type=str, default=None,
        help="Optional directory for visualization",
    )
    opts = cli.parse_args()

    source_dir = Path(opts.data_dir)
    wanted = set(opts.samples) if opts.samples else None

    entries = collect_images(source_dir, opts.max_samples, wanted)
    if not entries:
        print(f"No images found under {source_dir}")
        return

    print(f"Loading model: {opts.model}")
    model, processor = get_model_and_processor(opts.model)

    print(f"Running inference on {len(entries)} samples...")

    jsonl_path = Path(opts.output)
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)

    vis_root = None
    if opts.vis_dir:
        vis_root = Path(opts.vis_dir)

    with open(jsonl_path, "w", encoding="utf-8") as sink:
        for name, image_path, record_path in entries:
            print(f" {name}")

            caption, boxes = infer_caption_bbox(
                image_path,
                model,
                processor,
                prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
                max_new_tokens=opts.max_new_tokens,
            )

            line = json.dumps(
                {
                    "sample_or_stem": name,
                    "image": record_path,
                    "whole_caption": caption,
                    "bboxes": boxes,
                    "num_layers": len(boxes),
                    "coord": "top_left",
                },
                ensure_ascii=False,
            )
            sink.write(line + "\n")
            # Flush per record so finished results survive an interrupted run.
            sink.flush()

            if vis_root is not None:
                draw_boxes(Path(image_path), boxes, vis_root / f"{name}_vis.png")

    print(f"Wrote {jsonl_path}")
    if vis_root is not None:
        print(f"Visualizations saved to {vis_root}")


if __name__ == "__main__":
    main()