#!/usr/bin/env python3
"""Run whole-caption + bbox inference and save portable JSONL results."""

import argparse
import json
import math
import os
import re
from pathlib import Path

import torch
from PIL import Image, ImageDraw

# Support both package-style execution (``demo.infer``) and running the
# script directly from its own directory.
try:
    from demo.infer.vlm_bbox_inference import (
        get_model_and_processor,
        parse_bbox_output,
    )
except ImportError:
    from vlm_bbox_inference import (
        get_model_and_processor,
        parse_bbox_output,
    )

PROJECT_ROOT = Path(__file__).resolve().parents[2]


def resolve_default_bbox_model() -> str:
    """Return the default model path or repo id for bbox/caption inference.

    Resolution order:
      1. ``SYNLAYERS_BBOX_MODEL`` or ``SYNLAYERS_BBOX_MODEL_REPO`` env vars.
      2. ``PROJECT_ROOT`` itself when it contains both ``config.json`` and
         ``tokenizer_config.json``.
      3. ``PROJECT_ROOT / "Bbox-caption-8b"`` when that path exists.
      4. The literal identifier ``"SynLayers/Bbox-caption-8b"``.
    """
    env_model = os.environ.get("SYNLAYERS_BBOX_MODEL") or os.environ.get(
        "SYNLAYERS_BBOX_MODEL_REPO"
    )
    if env_model:
        return env_model

    root_has_model_files = (
        (PROJECT_ROOT / "config.json").exists()
        and (PROJECT_ROOT / "tokenizer_config.json").exists()
    )
    candidates = [
        PROJECT_ROOT if root_has_model_files else None,
        PROJECT_ROOT / "Bbox-caption-8b",
    ]
    for candidate in candidates:
        if candidate and candidate.exists():
            return str(candidate)
    return "SynLayers/Bbox-caption-8b"


CAPTION_BBOX_PROMPT_TOP_LEFT = (
    "This image is 1024 pixels in width and 1024 pixels in height. "
    "The coordinate origin is at the top-left corner of the image: x increases to the right, y increases downward. "
    "First describe the whole image in one detailed caption (whole_caption). "
    "Then list the bounding box for each visible layer or object. "
    "Each box is [x_left, y_top, x_right, y_bottom] in pixel coordinates (top-left origin, y downward). "
    "Output a single JSON object with exactly two keys: \"whole_caption\" (string) and \"boxes\" (list of [x_left,y_top,x_right,y_bottom] arrays). "
    "Output only this JSON, no other text or markdown."
)

DEFAULT_BBOX_MODEL = resolve_default_bbox_model()
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}


def _caption_and_boxes_from_json(candidate: str):
    """Decode *candidate* as a JSON object and extract ``(caption, boxes)``.

    Returns ``None`` when the text is not valid JSON, is not a JSON object,
    or its boxes entry is not a list, so the caller can try the next
    parsing strategy.
    """
    try:
        obj = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    if not isinstance(obj, dict):
        return None
    caption = obj.get("whole_caption") or obj.get("caption") or ""
    boxes = obj.get("boxes") or obj.get("bboxes") or []
    if not isinstance(boxes, list):
        return None
    return caption, boxes


def _to_finite_floats(values):
    """Convert *values* to a list of floats.

    Returns ``None`` if any value is non-numeric or not finite (NaN/inf), so
    callers can skip the box instead of crashing or emitting unusable
    coordinates.
    """
    try:
        coords = [float(v) for v in values]
    except (TypeError, ValueError):
        return None
    if not all(math.isfinite(c) for c in coords):
        return None
    return coords


def parse_json_caption_bbox(text: str):
    """Parse model output into `(whole_caption, boxes)`.

    Strategies, tried in order:
      1. Any Markdown-fenced segment (optionally tagged ``json``) that
         starts with ``{`` and decodes to an object with a list of boxes.
      2. The widest ``{...}`` span found anywhere in the text.
      3. ``parse_bbox_output`` on the raw text, with an empty caption.
    """
    text = (text or "").strip()

    if "```" in text:
        for segment in text.split("```"):
            segment = segment.strip()
            if segment.startswith("json"):
                segment = segment[4:].strip()
            if segment.startswith("{"):
                parsed = _caption_and_boxes_from_json(segment)
                if parsed is not None:
                    return parsed

    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        parsed = _caption_and_boxes_from_json(match.group(0))
        if parsed is not None:
            return parsed

    return "", parse_bbox_output(text)


def format_image_record_path(image_path: Path, data_dir: Path) -> str:
    """Return *image_path* relative to *data_dir* in POSIX form.

    Falls back to the bare file name when the image is not located under
    *data_dir*.
    """
    try:
        return image_path.relative_to(data_dir).as_posix()
    except ValueError:
        return image_path.name


def collect_images(data_dir: Path, max_samples: int | None, target_samples: set | None = None):
    """Collect images and keep a relative path for JSONL output.

    First looks for ``sample_*/whole_image.png`` directories. Only when none
    of those yield an image does it fall back to image files located directly
    in *data_dir*.

    Args:
        data_dir: Directory to scan.
        max_samples: Maximum number of entries to return; ``None`` or ``0``
            means no limit.
        target_samples: Optional set of sample directory names (or file
            stems in the flat layout) to keep.

    Returns:
        A list of ``(name, image_path, record_path)`` tuples.
    """
    data_dir = Path(data_dir)
    out = []

    for sample_dir in sorted(data_dir.glob("sample_*")):
        if not sample_dir.is_dir():
            continue
        if target_samples is not None and sample_dir.name not in target_samples:
            continue
        whole = sample_dir / "whole_image.png"
        if whole.exists():
            out.append((sample_dir.name, whole, format_image_record_path(whole, data_dir)))
            if max_samples and len(out) >= max_samples:
                return out

    if out:
        return out

    def _sort_key(p: Path):
        # Order "name_12" numerically by its trailing index. The file name is
        # a final tie-breaker so files sharing a stem (a.png / a.jpg) are
        # ordered deterministically regardless of set iteration order.
        parts = p.stem.rsplit("_", 1)
        try:
            return (parts[0], int(parts[-1]), p.name)
        except ValueError:
            return (p.stem, 0, p.name)

    all_imgs = [
        p
        for ext in IMAGE_EXTS
        for p in data_dir.glob(f"*{ext}")
        if p.is_file()
    ]
    for p in sorted(all_imgs, key=_sort_key):
        if target_samples is not None and p.stem not in target_samples:
            continue
        out.append((p.stem, p, format_image_record_path(p, data_dir)))
        if max_samples and len(out) >= max_samples:
            return out
    return out


def draw_boxes(image_path: Path, bboxes: list, out_path: Path, color: str = "lime", width: int = 3):
    """Draw bounding boxes on an image.

    Boxes that do not have exactly four entries, or whose entries are not
    finite numbers, are skipped. Corner order is normalised so that reversed
    coordinates still produce a valid rectangle. Parent directories of
    *out_path* are created as needed.
    """
    # The context manager closes the underlying file handle; ``convert``
    # returns an independent in-memory copy that outlives it.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    draw = ImageDraw.Draw(img)
    for b in bboxes:
        if len(b) != 4:
            continue
        coords = _to_finite_floats(b)
        if coords is None:
            continue
        x0, y0, x1, y1 = coords
        x0, x1 = min(x0, x1), max(x0, x1)
        y0, y1 = min(y0, y1), max(y0, y1)
        draw.rectangle([x0, y0, x1, y1], outline=color, width=width)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    img.save(out_path)


def infer_caption_bbox(image_path: str | Path, model, processor, *, prompt: str, max_new_tokens: int = 1024):
    """Run caption + bbox inference for one image.

    Args:
        image_path: Image to describe; a missing file yields ``("", [])``.
        model: Generation model exposing ``device`` and ``generate``.
        processor: Processor exposing ``apply_chat_template``,
            ``batch_decode`` and a ``tokenizer`` attribute.
        prompt: Text instruction sent alongside the image.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        ``(whole_caption, boxes)`` where each box is a list of four finite
        floats. Boxes that are not a list/tuple of at least four numeric,
        finite values are dropped rather than aborting the run.
    """
    path = Path(image_path)
    if not path.exists():
        return "", []

    content = [
        {"type": "image", "image": str(path.absolute())},
        {"type": "text", "text": prompt},
    ]
    messages = [{"role": "user", "content": content}]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
    inputs.pop("token_type_ids", None)

    # Explicit None check: a pad token id of 0 is valid and must not be
    # replaced by the EOS id through truthiness.
    tokenizer = processor.tokenizer
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.1,
            repetition_penalty=1.1,
            pad_token_id=pad_token_id,
        )

    # Strip the prompt tokens so only the newly generated text is decoded.
    input_len = inputs["input_ids"].shape[1]
    output_ids = generated[:, input_len:]
    output_text = processor.batch_decode(
        output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    raw = (output_text[0] or "").strip()
    whole_caption, bboxes = parse_json_caption_bbox(raw)

    result_boxes = []
    for b in bboxes:
        if not isinstance(b, (list, tuple)) or len(b) < 4:
            continue
        coords = _to_finite_floats(b[:4])
        if coords is not None:
            result_boxes.append(coords)
    return whole_caption, result_boxes


def main():
    """CLI entry point: run inference over a directory and write JSONL.

    One JSON record is written (and flushed) per image so partial results
    survive an interrupted run. When ``--vis-dir`` is given, an image with
    the predicted boxes drawn on it is saved per sample.
    """
    parser = argparse.ArgumentParser(
        description="Caption + bbox inference (top-left origin)"
    )
    parser.add_argument("--data-dir", type=str, default="testset", help="Directory containing sample_* or image files")
    parser.add_argument("--output", type=str, default="outputs/infer/caption_bbox_infer.jsonl", help="Output JSONL file")
    parser.add_argument("--model", type=str, default=DEFAULT_BBOX_MODEL, help="Model path (merged or LoRA) (default: %(default)s)")
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--samples", type=str, nargs="+", help="Specify sample names (e.g. sample_001)")
    parser.add_argument("--vis-dir", type=str, default=None, help="Optional directory for visualization")
    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    target_samples = set(args.samples) if args.samples else None
    rows = collect_images(data_dir, args.max_samples, target_samples)
    if not rows:
        print(f"No images found under {data_dir}")
        return

    print(f"Loading model: {args.model}")
    model, processor = get_model_and_processor(args.model)
    print(f"Running inference on {len(rows)} samples...")

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    vis_dir = Path(args.vis_dir) if args.vis_dir else None

    with open(out_path, "w", encoding="utf-8") as f:
        for name, image_path, image_record_path in rows:
            print(f" {name}")
            whole_caption, bboxes = infer_caption_bbox(
                image_path,
                model,
                processor,
                prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
                max_new_tokens=args.max_new_tokens,
            )
            num_layers = len(bboxes)
            record = {
                "sample_or_stem": name,
                "image": image_record_path,
                "whole_caption": whole_caption,
                "bboxes": bboxes,
                "num_layers": num_layers,
                "coord": "top_left",
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            # Flush per record so results are on disk even if a later
            # sample fails.
            f.flush()
            if vis_dir:
                draw_boxes(Path(image_path), bboxes, vis_dir / f"{name}_vis.png")

    print(f"Wrote {out_path}")
    if vis_dir:
        print(f"Visualizations saved to {vis_dir}")


if __name__ == "__main__":
    main()