synlayers / demo /infer /run_caption_bbox_infer.py
SynLayers's picture
Upload demo/infer/run_caption_bbox_infer.py with huggingface_hub
43bd07d verified
#!/usr/bin/env python3
"""Run whole-caption + bbox inference and save portable JSONL results."""
import os
import json
import re
from pathlib import Path
import torch
from PIL import Image, ImageDraw
try:
from demo.infer.vlm_bbox_inference import (
get_model_and_processor,
parse_bbox_output,
)
except ImportError:
from vlm_bbox_inference import (
get_model_and_processor,
parse_bbox_output,
)
# Third ancestor directory of this file's resolved path (``parents[2]``).
# resolve_default_bbox_model() searches it for a local model checkpoint.
# NOTE(review): presumably the repository root, given this script lives in
# demo/infer/ -- confirm if the file is ever moved.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
def resolve_default_bbox_model() -> str:
    """Pick the default model identifier for caption + bbox inference.

    Resolution order:

    1. The first non-empty value among the environment variables
       ``SYNLAYERS_BBOX_MODEL`` and ``SYNLAYERS_BBOX_MODEL_REPO``.
    2. ``PROJECT_ROOT`` itself, when it contains both ``config.json`` and
       ``tokenizer_config.json``.
    3. ``PROJECT_ROOT / "Bbox-caption-8b"``, when that path exists.
    4. The literal ``"SynLayers/Bbox-caption-8b"`` (presumably a Hub repo id).

    Returns:
        The chosen model path or identifier as a string.
    """
    for env_name in ("SYNLAYERS_BBOX_MODEL", "SYNLAYERS_BBOX_MODEL_REPO"):
        override = os.environ.get(env_name)
        if override:
            return override

    # The project root counts as a checkpoint only if both marker files exist.
    root_is_checkpoint = (
        (PROJECT_ROOT / "config.json").exists()
        and (PROJECT_ROOT / "tokenizer_config.json").exists()
    )
    if root_is_checkpoint and PROJECT_ROOT.exists():
        return str(PROJECT_ROOT)

    bundled = PROJECT_ROOT / "Bbox-caption-8b"
    if bundled.exists():
        return str(bundled)

    return "SynLayers/Bbox-caption-8b"
# Instruction prompt sent with every image by main(). It tells the model the
# image is 1024x1024, that coordinates use a top-left origin with y growing
# downward, and that the reply must be one JSON object with the keys
# "whole_caption" and "boxes".
# NOTE(review): the 1024x1024 size is hard-coded in the text; nothing visible
# in this file resizes or checks the input image -- confirm inputs match.
CAPTION_BBOX_PROMPT_TOP_LEFT = (
    "<image>This image is 1024 pixels in width and 1024 pixels in height. "
    "The coordinate origin is at the top-left corner of the image: x increases to the right, y increases downward. "
    "First describe the whole image in one detailed caption (whole_caption). "
    "Then list the bounding box for each visible layer or object. "
    "Each box is [x_left, y_top, x_right, y_bottom] in pixel coordinates (top-left origin, y downward). "
    "Output a single JSON object with exactly two keys: \"whole_caption\" (string) and \"boxes\" (list of [x_left,y_top,x_right,y_bottom] arrays). "
    "Output only this JSON, no other text or markdown."
)
# Resolved once, at import time; used as the default of the --model CLI option.
DEFAULT_BBOX_MODEL = resolve_default_bbox_model()
# File suffixes treated as images when collect_images() scans a flat directory.
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
def _caption_boxes_from_snippet(snippet: str):
    """Decode the JSON object starting at the first ``{`` of *snippet*.

    Returns ``(caption, boxes)`` when a usable object is found, else ``None``.
    Text after the decoded object is ignored, so a trailing remark from the
    model does not invalidate an otherwise well-formed answer.
    """
    start = snippet.find("{")
    if start < 0:
        return None
    try:
        obj, end = json.JSONDecoder().raw_decode(snippet, start)
    except json.JSONDecodeError:
        return None
    if not isinstance(obj, dict):
        return None
    # If more braces follow, the decoded object is only a prefix of the brace
    # span. In that case require a recognised key, so an unrelated leading
    # object does not pre-empt the plain-text bbox fallback.
    has_trailing_braces = snippet.rfind("}") >= end
    known_keys = ("whole_caption", "caption", "boxes", "bboxes")
    if has_trailing_braces and not any(key in obj for key in known_keys):
        return None
    caption = obj.get("whole_caption") or obj.get("caption") or ""
    boxes = obj.get("boxes") or obj.get("bboxes") or []
    if not isinstance(boxes, list):
        return None
    return caption, boxes


def parse_json_caption_bbox(text: str):
    """Parse model output into `(whole_caption, boxes)`.

    Parsing is attempted in this order:

    1. Each Markdown-fenced segment (split on three backticks) that starts
       with ``{``, after dropping an optional ``json`` language tag.
    2. The JSON object starting at the first ``{`` of the whole text.
    3. ``parse_bbox_output(text)`` as a last resort, with an empty caption.

    The caption is read from ``"whole_caption"`` or ``"caption"`` and the boxes
    from ``"boxes"`` or ``"bboxes"``; missing values default to ``""`` / ``[]``.
    Boxes are returned as decoded, without validating their contents.

    Args:
        text: Raw decoded model output; ``None`` is treated as empty.

    Returns:
        A ``(caption, boxes)`` tuple.
    """
    text = (text or "").strip()
    if "```" in text:
        for segment in text.split("```"):
            segment = segment.strip()
            if segment.startswith("json"):
                segment = segment[4:].strip()
            if segment.startswith("{"):
                parsed = _caption_boxes_from_snippet(segment)
                if parsed is not None:
                    return parsed
    parsed = _caption_boxes_from_snippet(text)
    if parsed is not None:
        return parsed
    # No usable JSON object: fall back to the plain-text bbox parser.
    boxes = parse_bbox_output(text)
    return "", boxes
def format_image_record_path(image_path: Path, data_dir: Path) -> str:
    """Return the path string stored for an image in the JSONL output.

    Args:
        image_path: Location of the image file.
        data_dir: Directory the stored path should be relative to.

    Returns:
        ``image_path`` relative to ``data_dir`` with forward slashes, or just
        the file name when ``image_path`` is not located under ``data_dir``.
    """
    try:
        relative = image_path.relative_to(data_dir)
    except ValueError:
        # Not under data_dir: fall back to the bare file name.
        return image_path.name
    return relative.as_posix()
def collect_images(data_dir: Path, max_samples: int | None, target_samples: set | None = None):
    """Collect images and keep a relative path for JSONL output.

    Two layouts are tried, in order:

    1. ``data_dir/sample_*/whole_image.png``: one entry per sample directory,
       visited in sorted order and named after the directory.
    2. Only when layout 1 yields nothing: image files directly inside
       ``data_dir`` whose name ends with one of ``IMAGE_EXTS`` (compared
       case-insensitively), named after the file stem.

    Flat files are ordered by ``(prefix, trailing integer)`` of the stem, split
    at the last underscore, with the file name as a final tie-breaker so the
    order is the same on every run.

    Args:
        data_dir: Directory to scan.
        max_samples: Maximum number of entries to return; ``None`` or ``0``
            means no limit.
        target_samples: When given, only names contained in it are kept.

    Returns:
        A list of ``(name, image_path, record_path)`` tuples.
    """
    data_dir = Path(data_dir)
    out = []

    def _limit_reached() -> bool:
        return bool(max_samples) and len(out) >= max_samples

    for sample_dir in sorted(data_dir.glob("sample_*")):
        if not sample_dir.is_dir():
            continue
        if target_samples is not None and sample_dir.name not in target_samples:
            continue
        whole = sample_dir / "whole_image.png"
        if not whole.exists():
            continue
        out.append((sample_dir.name, whole, format_image_record_path(whole, data_dir)))
        if _limit_reached():
            return out
    if out:
        return out

    def _sort_key(p: Path):
        parts = p.stem.rsplit("_", 1)
        try:
            # p.name breaks ties between files sharing a stem (a_1.png / a_1.jpg).
            return (parts[0], int(parts[-1]), p.name)
        except ValueError:
            return (p.stem, 0, p.name)

    # Match on the lower-cased name so ".PNG" / ".JPG" files are not skipped
    # on case-sensitive filesystems.
    suffixes = tuple(IMAGE_EXTS)
    flat_images = [
        p for p in data_dir.glob("*")
        if p.name.lower().endswith(suffixes) and p.is_file()
    ]
    for p in sorted(flat_images, key=_sort_key):
        if target_samples is not None and p.stem not in target_samples:
            continue
        out.append((p.stem, p, format_image_record_path(p, data_dir)))
        if _limit_reached():
            return out
    return out
def draw_boxes(image_path: Path, bboxes: list, out_path: Path, color: str = "lime", width: int = 3):
    """Save a copy of an image with box outlines drawn on it.

    Args:
        image_path: Image to open; it is converted to RGB before drawing.
        bboxes: Iterable of boxes. Entries whose length is not 4 are skipped;
            the four values are converted to float and the corners are
            reordered so the smaller x / y comes first.
        out_path: Destination file; its parent directory is created if needed.
        color: Outline colour passed to the drawing call.
        width: Outline width passed to the drawing call.
    """
    canvas = Image.open(image_path).convert("RGB")
    pen = ImageDraw.Draw(canvas)
    for box in bboxes:
        if len(box) == 4:
            xa, ya, xb, yb = (float(box[i]) for i in range(4))
            # Normalise so the first corner is the top-left one.
            corners = [min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)]
            pen.rectangle(corners, outline=color, width=width)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    canvas.save(out_path)
def infer_caption_bbox(image_path: str | Path, model, processor, *, prompt: str, max_new_tokens: int = 1024):
"""Run caption + bbox inference for one image."""
path = Path(image_path)
if not path.exists():
return "", []
content = [
{"type": "image", "image": str(path.absolute())},
{"type": "text", "text": prompt},
]
messages = [{"role": "user", "content": content}]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
inputs.pop("token_type_ids", None)
with torch.no_grad():
generated = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.1,
repetition_penalty=1.1,
pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id,
)
input_len = inputs["input_ids"].shape[1]
output_ids = generated[:, input_len:]
output_text = processor.batch_decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
raw = (output_text[0] or "").strip()
whole_caption, bboxes = parse_json_caption_bbox(raw)
result_boxes = []
for b in bboxes:
if isinstance(b, (list, tuple)) and len(b) >= 4:
result_boxes.append([float(b[0]), float(b[1]), float(b[2]), float(b[3])])
return whole_caption, result_boxes
def main():
    """Command-line entry point.

    Parses the CLI options, collects the images under ``--data-dir``, loads
    the model, then writes one JSON line per image to ``--output``. Each line
    is flushed as soon as it is written. When ``--vis-dir`` is given, an image
    with the predicted boxes drawn on it is also saved per sample.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="Caption + bbox inference (top-left origin)"
    )
    cli.add_argument(
        "--data-dir",
        type=str,
        default="testset",
        help="Directory containing sample_* or image files",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="outputs/infer/caption_bbox_infer.jsonl",
        help="Output JSONL file",
    )
    cli.add_argument(
        "--model",
        type=str,
        default=DEFAULT_BBOX_MODEL,
        help="Model path (merged or LoRA) (default: %(default)s)",
    )
    cli.add_argument("--max-samples", type=int, default=None)
    cli.add_argument("--max-new-tokens", type=int, default=1024)
    cli.add_argument(
        "--samples",
        type=str,
        nargs="+",
        help="Specify sample names (e.g. sample_001)",
    )
    cli.add_argument(
        "--vis-dir",
        type=str,
        default=None,
        help="Optional directory for visualization",
    )
    opts = cli.parse_args()

    source_dir = Path(opts.data_dir)
    wanted = set(opts.samples) if opts.samples else None
    jobs = collect_images(source_dir, opts.max_samples, wanted)
    if not jobs:
        print(f"No images found under {source_dir}")
        return

    print(f"Loading model: {opts.model}")
    model, processor = get_model_and_processor(opts.model)
    print(f"Running inference on {len(jobs)} samples...")

    jsonl_path = Path(opts.output)
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)
    vis_root = Path(opts.vis_dir) if opts.vis_dir else None

    with open(jsonl_path, "w", encoding="utf-8") as sink:
        for name, image_path, image_record_path in jobs:
            print(f" {name}")
            whole_caption, bboxes = infer_caption_bbox(
                image_path,
                model,
                processor,
                prompt=CAPTION_BBOX_PROMPT_TOP_LEFT,
                max_new_tokens=opts.max_new_tokens,
            )
            record = {
                "sample_or_stem": name,
                "image": image_record_path,
                "whole_caption": whole_caption,
                "bboxes": bboxes,
                "num_layers": len(bboxes),
                "coord": "top_left",
            }
            sink.write(json.dumps(record, ensure_ascii=False) + "\n")
            # Flush per record so finished results survive an interrupted run.
            sink.flush()
            if vis_root:
                draw_boxes(Path(image_path), bboxes, vis_root / f"{name}_vis.png")

    print(f"Wrote {jsonl_path}")
    if vis_root:
        print(f"Visualizations saved to {vis_root}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()