# cai-qi's picture
# Update Space
# c6d3f05
import os
import math
import torch
from typing import Optional
from PIL import Image, ImageDraw
import json
from typing import Any, Dict, Iterable, List, Sequence, Tuple
# Upper bound on how many layout boxes draw_bbox_layout renders by default.
MAX_BOX = 5

# Candidate (width, height) resolutions searched by find_closest_resolution.
# Apart from the square entry they come in landscape/portrait pairs.
PREDEFINED_RESOLUTIONS = [
    (2048, 2048),
    (2304, 1728), (1728, 2304),
    (2560, 1440), (1440, 2560),
    (2496, 1664), (1664, 2496),
    (3104, 1312), (1312, 3104),
    (2304, 1792), (1792, 2304),
]

# RGB outline colors; draw_bbox_layout assigns them to boxes in
# descending-area order, wrapping around when there are more boxes.
DEFAULT_COLORS = [
    (255, 0, 0), (0, 180, 0), (0, 0, 255),
    (204, 180, 0), (255, 0, 255), (0, 255, 255),
    (128, 0, 0), (0, 128, 0), (0, 0, 128),
    (128, 128, 0),
]
def load_layout_bboxes(layout_bboxes: Any) -> Any:
    """Load layout boxes from a JSON string, a JSON file path, or already-parsed data.

    Args:
        layout_bboxes: A path to an existing JSON file, a JSON-encoded string,
            or an already-decoded ``list``/``dict`` (returned unchanged).

    Returns:
        The decoded JSON value (a list of boxes or a dict wrapping one).

    Raises:
        json.JSONDecodeError: If a string is neither an existing file nor valid JSON.
    """
    # Already-decoded input needs no parsing; os.path.exists() would raise
    # TypeError on a list/dict.
    if isinstance(layout_bboxes, (list, dict)):
        return layout_bboxes
    # isfile (not exists) so a string naming a directory is parsed as JSON
    # instead of being handed to open().
    if os.path.isfile(layout_bboxes):
        with open(layout_bboxes, "r", encoding="utf-8") as f:
            return json.load(f)
    return json.loads(layout_bboxes)
def _unwrap_boxes(data: Any) -> Any:
if isinstance(data, dict):
for key in ("layout_bboxes", "bboxes", "boxes", "bbox_list"):
if key in data:
return data[key]
return data
def _as_bbox_and_text(item: Any) -> Tuple[Sequence[float], str]:
if isinstance(item, dict):
bbox = item.get("bbox") or item.get("box")
text = str(item.get("text") or item.get("label") or "")
if bbox is None:
raise ValueError(f"Missing bbox in layout item: {item!r}")
return bbox, text
if isinstance(item, (list, tuple)) and len(item) == 4:
return item, ""
raise ValueError(f"Unsupported layout bbox item: {item!r}")
def _xxyy_relative_to_absolute_bbox(bbox: Sequence[float], width: int, height: int) -> List[int]:
if len(bbox) != 4:
raise ValueError(f"Expected bbox with 4 values, got: {bbox!r}")
x1, x2, y1, y2 = [float(v) for v in bbox]
# Inference layout input is xxyy relative coordinates: [x1, x2, y1, y2].
# Values in [0, 1] are the intended format. Keep 0-100 support for convenience.
max_abs = max(abs(x1), abs(y1), abs(x2), abs(y2))
if max_abs <= 1.0:
x1, x2 = x1 * width, x2 * width
y1, y2 = y1 * height, y2 * height
elif max_abs <= 100.0:
x1, x2 = x1 / 100.0 * width, x2 / 100.0 * width
y1, y2 = y1 / 100.0 * height, y2 / 100.0 * height
x1, x2 = sorted((x1, x2))
y1, y2 = sorted((y1, y2))
x1 = max(0, min(width - 1, int(round(x1))))
y1 = max(0, min(height - 1, int(round(y1))))
x2 = max(0, min(width - 1, int(round(x2))))
y2 = max(0, min(height - 1, int(round(y2))))
if x2 <= x1 or y2 <= y1:
raise ValueError(f"Invalid bbox after scaling/clamping: {[x1, y1, x2, y2]!r}")
return [x1, y1, x2, y2]
def parse_layout_bboxes(layout_bboxes: Any, width: int, height: int) -> List[Dict[str, Any]]:
    """Convert xxyy relative layout boxes into the training-side bbox layout format.

    Args:
        layout_bboxes: A list of boxes, or a dict holding that list under one
            of ``layout_bboxes``/``bboxes``/``boxes``/``bbox_list``. Each box is
            either a 4-element ``[x1, x2, y1, y2]`` sequence or a dict with
            ``bbox``/``box`` and optional ``text``/``label``.
        width: Target image width in pixels.
        height: Target image height in pixels.

    Returns:
        One dict per box, in input order, with these fields:
        - ``bbox``: pixel coordinates ``[x1, y1, x2, y2]``.
        - ``text``: the box's caption.
        - ``color`` and ``image``: placeholder fields.
        - ``_orig_idx``: the box's position in the input, which
          ``draw_bbox_layout`` uses to map colors back.

    Raises:
        ValueError: If the unwrapped value is not a list, or any box is invalid.
    """
    raw_boxes = _unwrap_boxes(layout_bboxes)
    if not isinstance(raw_boxes, list):
        # Keep this key list in sync with _unwrap_boxes (it also accepts bbox_list).
        raise ValueError(
            "layout_bboxes must be a list, or a dict containing one of: "
            "layout_bboxes/bboxes/boxes/bbox_list"
        )
    parsed: List[Dict[str, Any]] = []
    for idx, item in enumerate(raw_boxes):
        bbox, text = _as_bbox_and_text(item)
        parsed.append({
            "bbox": _xxyy_relative_to_absolute_bbox(bbox, width, height),
            "color": "",
            "text": text,
            "image": None,
            "_orig_idx": idx,
        })
    return parsed
def _bbox_area(item: Dict[str, Any]) -> int:
x1, y1, x2, y2 = item["bbox"]
return max(0, x2 - x1) * max(0, y2 - y1)
def get_render_params(image_width: int, image_height: int) -> Tuple[int, int]:
    """Return ``(max_font_size, max_bbox_line_width)`` for an image of the given size.

    Both are truncated fractions (7% and 5%) of the geometric-mean side
    ``sqrt(image_width * image_height)``.
    """
    mean_side = math.sqrt(image_width * image_height)
    return tuple(int(mean_side * fraction) for fraction in (0.07, 0.05))
def draw_bbox_layout(
    bbox_list: List[Dict[str, Any]],
    image_width: int,
    image_height: int,
    max_bbox: int = MAX_BOX,
    max_bbox_line_width: int | None = None,
    bbox_line_gap: int | None = None,
    return_color: bool = False,
):
    """Draw a black layout image with colored boxes, matching the training-side layout style.

    Only the ``max_bbox`` largest boxes (by area) are drawn. The box at rank i
    gets ``DEFAULT_COLORS[i % len(DEFAULT_COLORS)]``. Its outline is
    ``max_bbox_line_width - i * bbox_line_gap`` pixels wide, but never below 5.

    Args:
        bbox_list: Boxes with pixel ``bbox`` ``[x1, y1, x2, y2]``. An optional
            ``_orig_idx`` overrides the position used to index ``color_list``.
        image_width: Canvas width in pixels.
        image_height: Canvas height in pixels.
        max_bbox: Maximum number of boxes to draw.
        max_bbox_line_width: Outline width of the largest box. Defaults to the
            value from ``get_render_params``.
        bbox_line_gap: Width decrement per rank. Defaults to
            ``max_bbox_line_width // max_bbox``, at least 1.
        return_color: Also return the per-box colors.

    Returns:
        The layout image. With ``return_color=True``, returns
        ``(image, color_list)`` instead: ``color_list[i]`` is the RGB color of
        input box ``i``, or ``None`` if that box was not drawn.
    """
    if max_bbox_line_width is None:
        _, max_bbox_line_width = get_render_params(image_width, image_height)
    if bbox_line_gap is None:
        # max(1, ...) on the divisor: max_bbox <= 0 used to raise ZeroDivisionError.
        bbox_line_gap = max(1, max_bbox_line_width // max(1, max_bbox))
    image = Image.new("RGB", (image_width, image_height), (0, 0, 0))
    draw = ImageDraw.Draw(image)
    color_list = [None] * len(bbox_list)
    # Sort (position, box) pairs so every box remembers its input position even
    # without "_orig_idx" (previously the sorted rank was used by mistake).
    # The sort is stable, so equal areas keep input order.
    ranked = sorted(
        enumerate(bbox_list), key=lambda pair: _bbox_area(pair[1]), reverse=True
    )[:max_bbox]
    for rank, (position, item) in enumerate(ranked):
        color = DEFAULT_COLORS[rank % len(DEFAULT_COLORS)]
        orig_idx = int(item.get("_orig_idx", position))
        if 0 <= orig_idx < len(color_list):
            color_list[orig_idx] = color
        # Larger boxes get thicker outlines.
        line_width = max(max_bbox_line_width - rank * bbox_line_gap, 5)
        draw.rectangle([int(v) for v in item["bbox"]], outline=color, width=line_width)
    if return_color:
        return image, color_list
    return image
def add_outer_border_keep_size(pil: Image.Image, color: Iterable[int], width: int) -> Image.Image:
    """Draw a border inside the image without changing its size.

    Args:
        pil: Source image; it is not modified.
        color: RGB components of the border.
        width: Border thickness in pixels. Values <= 0 return an unbordered RGB
            copy. Values above half the shorter side are capped, at which point
            the border already covers the image along that side.

    Returns:
        A new RGB image of the same size with the border drawn.
    """
    # convert() returns a new image (a copy even when already RGB), so the
    # source is never mutated and no extra copy is needed.
    img = pil.convert("RGB")
    color_tuple = tuple(int(c) for c in color)
    w, h = img.size
    # Past (min(w, h) + 1) // 2 rings the rectangle would have x1 < x0 or
    # y1 < y0, which recent Pillow versions reject with ValueError.
    width = min(max(0, int(width)), (min(w, h) + 1) // 2)
    if width == 0:
        return img
    draw = ImageDraw.Draw(img)
    # One 1-px ring per step, moving inward.
    for t in range(width):
        draw.rectangle([t, t, w - 1 - t, h - 1 - t], outline=color_tuple)
    return img
def create_layout_reference_images(
    ref_pils: Sequence[Image.Image],
    layout_bboxes: Any,
    image_width: int,
    image_height: int,
    ref_max_size: int | None = None,
    patch_size: int = 32,
) -> List[Image.Image]:
    """Create bordered reference images plus one layout image.

    Reference ``i`` gets an inner border in the color that ``draw_bbox_layout``
    assigned to layout box ``i``. If box ``i`` was not drawn or does not exist,
    it falls back to ``DEFAULT_COLORS[i % len(DEFAULT_COLORS)]``.

    Args:
        ref_pils: Reference images, in the same order as the layout boxes.
        layout_bboxes: Boxes in any form accepted by ``parse_layout_bboxes``.
        image_width: Width of the layout canvas in pixels.
        image_height: Height of the layout canvas in pixels.
        ref_max_size: If given, each reference is first resized with
            ``resize_pilimage(ref, ref_max_size, patch_size)``.
        patch_size: Size multiple passed to ``resize_pilimage``.

    Returns:
        A list of PIL images: the bordered references in input order, followed
        by the layout image.
    """
    parsed_boxes = parse_layout_bboxes(layout_bboxes, image_width, image_height)
    layout_image, color_list = draw_bbox_layout(
        parsed_boxes,
        image_width=image_width,
        image_height=image_height,
        return_color=True,
    )
    output_refs: List[Image.Image] = []
    for idx, ref in enumerate(ref_pils):
        if ref_max_size is not None:
            ref = resize_pilimage(ref, ref_max_size, patch_size)
        color = color_list[idx] if idx < len(color_list) else None
        if color is None:
            # NOTE(review): this fallback can coincide with a color already used
            # by a drawn box (e.g. the undrawn smallest box at index 0 gets the
            # same red as the largest box) -- confirm that is acceptable.
            color = DEFAULT_COLORS[idx % len(DEFAULT_COLORS)]
        # Border thickness is 4% of the reference's geometric-mean side.
        line_width = int(math.sqrt(ref.width * ref.height) * 0.04)
        output_refs.append(add_outer_border_keep_size(ref, color, line_width))
    output_refs.append(layout_image)
    return output_refs
def find_closest_resolution(width, height, candidates=None):
    """Return the ``(w, h)`` candidate whose aspect ratio is closest to ``width / height``.

    Args:
        width: Width of the size whose aspect ratio should be matched.
        height: Height of that size.
        candidates: Optional sequence of ``(w, h)`` pairs to search. Defaults to
            ``PREDEFINED_RESOLUTIONS``.

    Returns:
        The first candidate with the smallest absolute ratio difference, or
        ``None`` if ``candidates`` is empty.
    """
    if candidates is None:
        candidates = PREDEFINED_RESOLUTIONS
    img_ratio = width / height
    best_res = None
    min_diff = float("inf")
    for w, h in candidates:
        diff = abs(w / h - img_ratio)
        # Strict '<' keeps the earliest candidate on ties.
        if diff < min_diff:
            min_diff = diff
            best_res = (w, h)
    return best_res
def resize_pilimage(pil_image: Image.Image, image_size: int, patch_size: int = 16, resampler=Image.BICUBIC) -> Image.Image:
    """Resize and center-crop an image to a patch-aligned size of area about ``image_size ** 2``.

    The output size is chosen among candidates whose sides are multiples of
    ``patch_size`` and whose area does not exceed ``image_size * image_size``.
    The image is scaled so one side matches exactly, then the other side is
    center-cropped to the chosen size.

    Args:
        pil_image: Source image.
        image_size: Target side length; the target area is its square.
        patch_size: Both output sides are multiples of this value.
        resampler: Pillow resampling filter for the final resize.

    Returns:
        The resized and cropped image.

    NOTE(review): if a chosen side rounds down to 0 (e.g. ``image_size`` much
    smaller than ``patch_size``, or an extreme aspect ratio), the division by
    ``new_size`` below raises ZeroDivisionError.
    """
    # Cheap pre-shrink: halve with BOX filtering while the shorter side is
    # still at least twice the target side.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    m = patch_size
    width, height = pil_image.width, pil_image.height
    # Target area and the uniform scale factor that would reach it.
    S_max = image_size * image_size
    scale = S_max / (width * height)
    scale = math.sqrt(scale)

    # Four candidates mixing round/floor per axis, each snapped down to a
    # multiple of m.
    new_sizes = [
        (round(width * scale) // m * m, round(height * scale) // m * m),
        (round(width * scale) // m * m, math.floor(height * scale) // m * m),
        (math.floor(width * scale) // m * m, round(height * scale) // m * m),
        (math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
    ]
    new_sizes = sorted(new_sizes, key=lambda x: x[0] * x[1], reverse=True)

    # Pick the largest candidate that fits the area budget; if none fits, the
    # loop ends on the last (smallest) candidate.
    for new_size in new_sizes:
        if new_size[0] * new_size[1] <= S_max:
            break

    # Per-axis downscale ratios. The axis with the smaller ratio is matched
    # exactly; the other axis is resized larger, then center-cropped.
    s1 = width / new_size[0]
    s2 = height / new_size[1]
    if s1 < s2:
        pil_image = pil_image.resize([new_size[0], round(height / s1)], resample=resampler)
        top = (round(height / s1) - new_size[1]) // 2
        pil_image = pil_image.crop((0, top, new_size[0], top + new_size[1]))
    else:
        pil_image = pil_image.resize([round(width / s2), new_size[1]], resample=resampler)
        left = (round(width / s2) - new_size[0]) // 2
        pil_image = pil_image.crop((left, 0, left + new_size[0], new_size[1]))
    return pil_image
def calculate_dimensions(max_size, ratio):
    """Return ``(width, height)`` with aspect ``ratio`` and area near ``max_size ** 2``.

    Each side is truncated toward zero to a multiple of 32.
    """
    raw_width = math.sqrt(max_size * max_size * ratio)
    raw_height = raw_width / ratio
    return tuple(int(side / 32) * 32 for side in (raw_width, raw_height))
def get_rope_index_fix_point(
    spatial_merge_size,
    image_token_id,
    video_token_id,
    vision_start_token_id,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    skip_vision_start_token=None,
    fix_point=4096,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute 3-row (t, h, w) rotary position ids for mixed text/vision sequences.

    NOTE(review): the structure appears to follow the multimodal ``get_rope_index``
    of Qwen2-VL-style models, extended with ``skip_vision_start_token`` and
    ``fix_point``. Confirm against the upstream implementation.

    Text tokens get consecutive positions, identical on all three rows. Each
    vision block is identified by the token right after a
    ``vision_start_token_id``. It gets temporal/height/width grid indices from
    ``image_grid_thw`` / ``video_grid_thw`` (h and w divided by
    ``spatial_merge_size``), offset by the running position.

    For an image whose ``skip_vision_start_token`` entry is truthy:
    - That entry is subtracted from the preceding text length.
    - While ``fix_point > 0``, the image's grid starts at the absolute position
      ``fix_point``.
    - ``fix_point`` is then set to 0, so later such images are offset by the
      running position only.

    Args:
        spatial_merge_size: Divisor applied to each grid's h and w.
        image_token_id: Token id marking an image block.
        video_token_id: Token id marking a video block.
        vision_start_token_id: Token id that precedes each vision block.
        input_ids: Token ids indexed as ``(batch, seq)``.
        image_grid_thw: One ``(t, h, w)`` row per image, in order of appearance.
        video_grid_thw: One ``(t, h, w)`` row per video. Each video is expanded
            into one row per frame with t=1.
        attention_mask: Positions where it equals 1 receive computed ids; the
            others are set to 1. Defaults to all ones.
        skip_vision_start_token: Per-image sequence indexed by image order. It
            is indexed for every vision block, so it must not be ``None`` when
            vision tokens are present.
        fix_point: Absolute start position for the first "skip" image grid.

    Returns:
        ``(position_ids, mrope_position_deltas)`` with shapes ``(3, batch, seq)``
        and ``(batch, 1)``. Each delta is ``max position + 1 - padded seq length``.
    """
    if video_grid_thw is not None:
        # Expand each video into one grid row per frame, each with t=1.
        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
        video_grid_thw[:, 0] = 1
    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        # Masked-out positions keep the initial value 1.
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        # Image/video counters run across the whole batch, not per row.
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        # NOTE: rebinds input_ids to the current row; only attended tokens are kept.
        for i, input_ids in enumerate(total_input_ids):
            input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums = 0, 0
            # The token right after each vision-start marker tells image from video.
            # NOTE(review): a vision-start at the very last position would index
            # past the end here.
            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                # Locate the next image and video token at or after st; whichever
                # comes first is the next block.
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                # Text between the previous block and this one. For "skip" images,
                # the skip entry is subtracted from this length.
                # NOTE(review): indexed by image_index even when this block is a
                # video; before any image, image_index - 1 == -1 reads the last entry.
                text_len = ed - st
                text_len -= skip_vision_start_token[image_index - 1]
                text_len = max(0, text_len)
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                # Per-token (t, h, w) grid coordinates, flattened in t-major, then h, then w order.
                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                if skip_vision_start_token[image_index - 1]:
                    # When fix_point > 0: (fix_point - st_idx) + st_idx == fix_point,
                    # so the grid starts at absolute position fix_point. fix_point is
                    # then zeroed for the rest of the call, including later batch rows.
                    if fix_point > 0:
                        fix_point = fix_point - st_idx
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx)
                    fix_point = 0
                else:
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                # Skip past this block's vision tokens.
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w
            # Trailing text after the last vision block.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        # input_ids is the last batch row here; it lives on the same device as the batch.
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas
    else:
        # Text-only path: identical positions on all three rows.
        if attention_mask is not None:
            # Running count of attended tokens; padded positions are set to 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )
        return position_ids, mrope_position_deltas