| import os |
| import math |
| import torch |
| from typing import Optional |
| from PIL import Image, ImageDraw |
| import json |
| from typing import Any, Dict, Iterable, List, Sequence, Tuple |
|
|
# Default cap on how many boxes draw_bbox_layout renders onto the layout image.
MAX_BOX = 5

# Candidate output resolutions: one square size, then (landscape, portrait)
# pairs where each portrait entry is its landscape entry with the sides swapped.
PREDEFINED_RESOLUTIONS = [(2048, 2048)] + [
    pair
    for w, h in ((2304, 1728), (2560, 1440), (2496, 1664), (3104, 1312), (2304, 1792))
    for pair in ((w, h), (h, w))
]

# RGB outline colors, handed out to layout boxes by area rank (largest box first).
DEFAULT_COLORS = [
    (255, 0, 0), (0, 180, 0), (0, 0, 255), (204, 180, 0), (255, 0, 255),
    (0, 255, 255), (128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0),
]
|
|
def load_layout_bboxes(layout_bboxes: Any) -> Any:
    """Load layout boxes from a JSON string, a JSON file path, or already-decoded data.

    Args:
        layout_bboxes: one of
            * a ``list``/``dict`` that is already decoded (returned unchanged);
            * a ``str``/``bytes``/``os.PathLike`` naming an existing regular file,
              which is read as UTF-8 JSON;
            * otherwise a JSON document (``str``/``bytes``) parsed with ``json.loads``.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: if the text (or file content) is not valid JSON.
        TypeError: if the input is of an unsupported type.
    """
    # Already-decoded input: previously os.path.exists() raised TypeError here.
    if isinstance(layout_bboxes, (list, dict)):
        return layout_bboxes
    # Only path-like values are probed on disk. os.path.exists() also accepts ints
    # (as file descriptors) and returns True for directories; isfile() avoids both.
    if isinstance(layout_bboxes, (str, bytes, os.PathLike)) and os.path.isfile(layout_bboxes):
        with open(layout_bboxes, "r", encoding="utf-8") as f:
            return json.load(f)
    return json.loads(layout_bboxes)
|
|
| def _unwrap_boxes(data: Any) -> Any: |
| if isinstance(data, dict): |
| for key in ("layout_bboxes", "bboxes", "boxes", "bbox_list"): |
| if key in data: |
| return data[key] |
| return data |
|
|
| def _as_bbox_and_text(item: Any) -> Tuple[Sequence[float], str]: |
| if isinstance(item, dict): |
| bbox = item.get("bbox") or item.get("box") |
| text = str(item.get("text") or item.get("label") or "") |
| if bbox is None: |
| raise ValueError(f"Missing bbox in layout item: {item!r}") |
| return bbox, text |
| if isinstance(item, (list, tuple)) and len(item) == 4: |
| return item, "" |
| raise ValueError(f"Unsupported layout bbox item: {item!r}") |
|
|
|
|
| def _xxyy_relative_to_absolute_bbox(bbox: Sequence[float], width: int, height: int) -> List[int]: |
| if len(bbox) != 4: |
| raise ValueError(f"Expected bbox with 4 values, got: {bbox!r}") |
| x1, x2, y1, y2 = [float(v) for v in bbox] |
|
|
| |
| |
| max_abs = max(abs(x1), abs(y1), abs(x2), abs(y2)) |
| if max_abs <= 1.0: |
| x1, x2 = x1 * width, x2 * width |
| y1, y2 = y1 * height, y2 * height |
| elif max_abs <= 100.0: |
| x1, x2 = x1 / 100.0 * width, x2 / 100.0 * width |
| y1, y2 = y1 / 100.0 * height, y2 / 100.0 * height |
|
|
| x1, x2 = sorted((x1, x2)) |
| y1, y2 = sorted((y1, y2)) |
| x1 = max(0, min(width - 1, int(round(x1)))) |
| y1 = max(0, min(height - 1, int(round(y1)))) |
| x2 = max(0, min(width - 1, int(round(x2)))) |
| y2 = max(0, min(height - 1, int(round(y2)))) |
| if x2 <= x1 or y2 <= y1: |
| raise ValueError(f"Invalid bbox after scaling/clamping: {[x1, y1, x2, y2]!r}") |
| return [x1, y1, x2, y2] |
|
|
def parse_layout_bboxes(layout_bboxes: Any, width: int, height: int) -> List[Dict[str, Any]]:
    """Convert xxyy relative layout boxes into the training-side bbox layout format.

    Args:
        layout_bboxes: a list of layout items, or a dict wrapping that list under
            one of the keys ``layout_bboxes``/``bboxes``/``boxes``/``bbox_list``.
            Each item is a bare 4-element list/tuple ``[x1, x2, y1, y2]`` or a dict
            with ``bbox``/``box`` and an optional ``text``/``label``.
        width: canvas width in pixels used to scale the boxes.
        height: canvas height in pixels used to scale the boxes.

    Returns:
        One dict per input item, in input order, with keys ``bbox`` (absolute
        ``[x1, y1, x2, y2]``), ``color`` (``""``), ``text``, ``image`` (``None``)
        and ``_orig_idx`` (the item's position in the input list).

    Raises:
        ValueError: if the container is not a list (after unwrapping) or any item
            is malformed.
    """
    raw_boxes = _unwrap_boxes(layout_bboxes)
    if not isinstance(raw_boxes, list):
        # Keep this key list in sync with the keys _unwrap_boxes looks for.
        raise ValueError(
            "layout_bboxes must be a list, or a dict containing one of: "
            "layout_bboxes/bboxes/boxes/bbox_list"
        )

    parsed = []
    for idx, item in enumerate(raw_boxes):
        bbox, text = _as_bbox_and_text(item)
        parsed.append({
            "bbox": _xxyy_relative_to_absolute_bbox(bbox, width, height),
            "color": "",
            "text": text,
            "image": None,
            # Lets draw_bbox_layout map colors back after it re-sorts by area.
            "_orig_idx": idx,
        })
    return parsed
|
|
| def _bbox_area(item: Dict[str, Any]) -> int: |
| x1, y1, x2, y2 = item["bbox"] |
| return max(0, x2 - x1) * max(0, y2 - y1) |
|
|
def get_render_params(image_width: int, image_height: int) -> Tuple[int, int]:
    """Return ``(max_font_size, max_bbox_line_width)`` scaled to the image size.

    Both values are fixed fractions (7% and 5%) of the geometric-mean edge
    ``sqrt(image_width * image_height)``, truncated to ``int``.
    """
    mean_edge = math.sqrt(image_width * image_height)
    return int(mean_edge * 0.07), int(mean_edge * 0.05)
|
|
def draw_bbox_layout(
    bbox_list: List[Dict[str, Any]],
    image_width: int,
    image_height: int,
    max_bbox: int = MAX_BOX,
    max_bbox_line_width: int | None = None,
    bbox_line_gap: int | None = None,
    return_color: bool = False,
):
    """Draw a black layout image with colored boxes, matching the training-side layout style.

    Only the ``max_bbox`` largest boxes (by area) are drawn. The box at area rank
    ``r`` (0 = largest) gets color ``DEFAULT_COLORS[r % len(DEFAULT_COLORS)]`` and
    an outline ``max(max_bbox_line_width - r * bbox_line_gap, 5)`` pixels wide.

    Args:
        bbox_list: items with an absolute ``[x1, y1, x2, y2]`` ``"bbox"``. An
            optional ``"_orig_idx"`` chooses which ``color_list`` slot receives the
            item's color; without it the item's position in ``bbox_list`` is used.
        image_width: canvas width in pixels.
        image_height: canvas height in pixels.
        max_bbox: maximum number of boxes to draw.
        max_bbox_line_width: outline width for the largest box; derived from the
            canvas size via ``get_render_params`` when ``None``.
        bbox_line_gap: per-rank decrease of the outline width; defaults to
            ``max_bbox_line_width // max_bbox`` (at least 1).
        return_color: also return the per-item color list.

    Returns:
        The layout image, or ``(image, color_list)`` when ``return_color`` is true.
        ``color_list[k]`` is the RGB tuple used for item ``k``, or ``None`` if that
        item was not drawn.
    """
    if max_bbox_line_width is None:
        _, max_bbox_line_width = get_render_params(image_width, image_height)
    if bbox_line_gap is None:
        # max(1, max_bbox) guards max_bbox == 0 (nothing is drawn) against ZeroDivisionError.
        bbox_line_gap = max(1, max_bbox_line_width // max(1, max_bbox))

    image = Image.new("RGB", (image_width, image_height), (0, 0, 0))
    draw = ImageDraw.Draw(image)
    color_list = [None] * len(bbox_list)
    # Sort positions rather than items (the sort is stable, giving the same order as before) so
    # an item lacking "_orig_idx" reports its color in its own slot rather than
    # in the slot of its area rank.
    order = sorted(
        range(len(bbox_list)),
        key=lambda k: _bbox_area(bbox_list[k]),
        reverse=True,
    )[:max_bbox]

    for rank, list_idx in enumerate(order):
        item = bbox_list[list_idx]
        color = DEFAULT_COLORS[rank % len(DEFAULT_COLORS)]
        orig_idx = int(item.get("_orig_idx", list_idx))
        if 0 <= orig_idx < len(color_list):
            color_list[orig_idx] = color
        # Larger boxes get thicker outlines; never thinner than 5 px.
        line_width = max(max_bbox_line_width - rank * bbox_line_gap, 5)
        draw.rectangle([int(v) for v in item["bbox"]], outline=color, width=line_width)

    if return_color:
        return image, color_list
    return image
|
|
def add_outer_border_keep_size(pil: Image.Image, color: Iterable[int], width: int) -> Image.Image:
    """Draw a border inside the image without changing its size.

    The drawing happens on an RGB copy, so ``pil`` itself is left untouched. The
    border is drawn as concentric 1-pixel rectangles, from the outer edge inwards.
    A width of at least half the shorter side fills the whole image.

    Args:
        pil: source image (converted to RGB).
        color: color components; each is converted with ``int``.
        width: border thickness in pixels; values <= 0 mean "no border".

    Returns:
        A new RGB image with the same size as ``pil``.
    """
    img = pil.convert("RGB").copy()
    color_tuple = tuple(int(c) for c in color)
    width = max(0, int(width))
    if width == 0:
        return img

    draw = ImageDraw.Draw(img)
    w, h = img.size
    # Past half the shorter side, the ring corners would cross (x1 < x0 or y1 < y0),
    # which Pillow rejects; by then every pixel is already covered.
    ring_count = min(width, (min(w, h) + 1) // 2)
    for t in range(ring_count):
        draw.rectangle([t, t, w - 1 - t, h - 1 - t], outline=color_tuple)
    return img
|
|
def create_layout_reference_images(
    ref_pils: Sequence["Image.Image"],
    layout_bboxes: Any,
    image_width: int,
    image_height: int,
    ref_max_size: int | None = None,
    patch_size: int = 32,
) -> List["Image.Image"]:
    """Create bordered reference images plus one layout image, all as PIL images.

    The layout boxes are parsed and drawn on a black canvas by
    ``draw_bbox_layout``. Reference ``k`` is (optionally) resized and given an
    inner border in the color used for layout box ``k``. Its thickness is 4% of
    ``sqrt(ref.width * ref.height)``.

    Args:
        ref_pils: PIL reference images, ordered like the layout boxes.
        layout_bboxes: any input accepted by ``parse_layout_bboxes``.
        image_width: width of the layout canvas in pixels.
        image_height: height of the layout canvas in pixels.
        ref_max_size: if given, each reference is resized with ``resize_pilimage``.
        patch_size: side-length multiple passed to ``resize_pilimage``.

    Returns:
        A flat list of ``len(ref_pils) + 1`` PIL images: the bordered references
        in input order, followed by the layout image.
    """
    parsed_boxes = parse_layout_bboxes(layout_bboxes, image_width, image_height)
    layout_image, color_list = draw_bbox_layout(
        parsed_boxes,
        image_width=image_width,
        image_height=image_height,
        return_color=True,
    )

    output_refs: List["Image.Image"] = []
    for idx, ref in enumerate(ref_pils):
        if ref_max_size is not None:
            ref = resize_pilimage(ref, ref_max_size, patch_size)
        # NOTE(review): for a reference whose box was not drawn (beyond max_bbox, or
        # no box at all), the fallback DEFAULT_COLORS[idx % N] may equal the color
        # of a drawn box belonging to another reference. Confirm this matches the
        # training-side behavior.
        color = color_list[idx] if idx < len(color_list) and color_list[idx] is not None else DEFAULT_COLORS[idx % len(DEFAULT_COLORS)]
        line_width = int(math.sqrt(ref.width * ref.height) * 0.04)
        bordered = add_outer_border_keep_size(ref, color, line_width)
        output_refs.append(bordered)
    output_refs.append(layout_image)
    return output_refs
|
|
|
|
def find_closest_resolution(width, height, resolutions=None):
    """Return the candidate ``(w, h)`` whose aspect ratio is closest to ``width / height``.

    Args:
        width: source width (any positive number).
        height: source height (must be non-zero).
        resolutions: iterable of candidate ``(w, h)`` pairs; defaults to
            ``PREDEFINED_RESOLUTIONS``.

    Returns:
        The first candidate with the smallest ``abs(w / h - width / height)``, or
        ``None`` if there are no candidates.
    """
    if resolutions is None:
        resolutions = PREDEFINED_RESOLUTIONS
    img_ratio = width / height
    best_res = None
    min_diff = float("inf")
    for w, h in resolutions:
        diff = abs(w / h - img_ratio)
        # Strict comparison: on ties the earlier candidate wins.
        if diff < min_diff:
            min_diff = diff
            best_res = (w, h)
    return best_res
|
|
def resize_pilimage(pil_image, image_size, patch_size=16, resampler=Image.BICUBIC):
    """Resize and center-crop an image to about ``image_size ** 2`` pixels on a patch grid.

    Steps:
      1. While the shorter side is at least ``2 * image_size``, halve the image
         with a BOX filter (cheap pre-downscale).
      2. Compute the area-preserving target size for ``image_size ** 2`` pixels,
         build four round/floor variants with each side snapped down to a
         multiple of ``patch_size`` (never below ``patch_size``), and keep the
         largest one whose area fits the budget (else the smallest variant).
      3. Resize with ``resampler`` so one side matches exactly, then center-crop
         the other side to the target.

    Args:
        pil_image: PIL image to resize.
        image_size: target edge length of the equivalent square, in pixels.
        patch_size: both output sides are multiples of this value.
        resampler: PIL resampling filter for the final resize.

    Returns:
        The resized and cropped PIL image.

    Raises:
        ValueError: if ``image_size`` or ``patch_size`` is not positive.
    """
    if image_size <= 0 or patch_size <= 0:
        raise ValueError(f"image_size and patch_size must be positive, got {image_size}, {patch_size}")

    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    m = patch_size
    width, height = pil_image.width, pil_image.height
    S_max = image_size * image_size
    scale = math.sqrt(S_max / (width * height))

    def _snap(v):
        # Snap down to a multiple of m, but never to 0. A zero side (very thin
        # image or image_size < patch_size) used to crash the divisions below.
        return max(m, v // m * m)

    w_round, w_floor = round(width * scale), math.floor(width * scale)
    h_round, h_floor = round(height * scale), math.floor(height * scale)
    new_sizes = [
        (_snap(w_round), _snap(h_round)),
        (_snap(w_round), _snap(h_floor)),
        (_snap(w_floor), _snap(h_round)),
        (_snap(w_floor), _snap(h_floor)),
    ]
    new_sizes = sorted(new_sizes, key=lambda x: x[0] * x[1], reverse=True)
    # Largest candidate within the pixel budget; otherwise fall back to the smallest.
    new_size = next((s for s in new_sizes if s[0] * s[1] <= S_max), new_sizes[-1])

    s1 = width / new_size[0]
    s2 = height / new_size[1]
    if s1 < s2:
        # Width is the binding side: fit it exactly, crop height around the center.
        pil_image = pil_image.resize([new_size[0], round(height / s1)], resample=resampler)
        top = (round(height / s1) - new_size[1]) // 2
        pil_image = pil_image.crop((0, top, new_size[0], top + new_size[1]))
    else:
        # Height is the binding side: fit it exactly, crop width around the center.
        pil_image = pil_image.resize([round(width / s2), new_size[1]], resample=resampler)
        left = (round(width / s2) - new_size[0]) // 2
        pil_image = pil_image.crop((left, 0, left + new_size[0], new_size[1]))

    return pil_image
|
|
def calculate_dimensions(max_size, ratio, multiple=32):
    """Return ``(width, height)`` with area about ``max_size ** 2`` and ``width / height`` about ``ratio``.

    Args:
        max_size: edge length of the equivalent square.
        ratio: desired ``width / height``; must be positive.
        multiple: both sides are truncated down to a multiple of this value
            (default 32).

    Returns:
        ``(width, height)`` as ints.

    Raises:
        ValueError: if ``ratio`` is not positive.
    """
    if ratio <= 0:
        # ratio == 0 used to surface as ZeroDivisionError, negatives as a math domain error.
        raise ValueError(f"ratio must be positive, got {ratio!r}")
    width = math.sqrt(max_size * max_size * ratio)
    height = width / ratio
    width = int(width / multiple) * multiple
    height = int(height / multiple) * multiple
    return width, height
|
|
def get_rope_index_fix_point(
    spatial_merge_size,
    image_token_id,
    video_token_id,
    vision_start_token_id,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    skip_vision_start_token=None,
    fix_point=4096,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute 3-axis (t, h, w) rotary position ids, optionally pinning a vision block at ``fix_point``.

    The structure appears to be adapted from a Qwen2-VL-style ``get_rope_index``
    (NOTE(review): presumably; confirm). Text tokens get the same position on
    all three axes. Each vision block gets a ``(t, h, w)`` grid of positions
    sized from ``image_grid_thw`` / ``video_grid_thw``, with h and w divided by
    ``spatial_merge_size``.

    Args:
        spatial_merge_size: integer divisor applied to the h and w grid sizes.
        image_token_id: token id of image placeholder tokens.
        video_token_id: token id of video placeholder tokens.
        vision_start_token_id: token id preceding each vision block. The token
            right after it decides whether the block counts as image or video.
        input_ids: ``(batch, seq)`` token ids.
        image_grid_thw: one ``(t, h, w)`` row per image. Rows are consumed in
            order across the whole batch (the counter is not reset per row).
        video_grid_thw: one ``(t, h, w)`` row per video. Each video is expanded
            into ``t`` single-frame rows before use.
        attention_mask: same shape as ``input_ids``. Only positions where it is 1
            receive computed ids. Defaults to all ones.
        skip_vision_start_token: per-image values indexed by ``image_index - 1``.
            A truthy value is subtracted from the preceding text length and
            switches that block to fix-point placement. It must be indexable
            whenever vision tokens are present (``None`` cannot be indexed).
        fix_point: absolute start position for the first skip block; it is then
            reset to 0 for all later blocks.

    Returns:
        ``(position_ids, mrope_position_deltas)``. ``position_ids`` has shape
        ``(3, batch, seq)`` and ``mrope_position_deltas`` has shape ``(batch, 1)``.
    """
    if video_grid_thw is not None:
        # Expand every video into t per-frame rows with t = 1, so each frame is
        # handled as its own vision block below.
        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
        video_grid_thw[:, 0] = 1

    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        # Padded (mask == 0) slots keep the value 1; only masked-in slots are overwritten below.
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        # These counters run across the whole batch, so the grid tensors are consumed in batch order.
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        # NOTE: rebinding input_ids shadows the argument. After the loop it holds the
        # last (masked) row; only its .device is used afterwards.
        for i, input_ids in enumerate(total_input_ids):
            input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums = 0, 0
            # Classify each vision block by the token that follows its vision_start token.
            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                # Find the next image/video placeholder at or after st; len + 1 is a
                # sentinel meaning "none left of this kind".
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                # Grid size in LLM tokens: the spatial axes are reduced by the merge factor.
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                # Number of text tokens between the previous block and this placeholder.
                text_len = ed - st

                # NOTE(review): this subtraction is also applied to video blocks,
                # using the last image's index (or -1 if no image has been seen).
                # It also drops text positions without dropping tokens, so a row's
                # positions would total (tokens - sum of skips). The assignment into
                # position_ids below needs equal counts. Verify the expected input format.
                text_len -= skip_vision_start_token[image_index - 1]
                text_len = max(0, text_len)

                # Continue numbering right after the largest position emitted so far.
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                # (t, h, w) coordinates for every LLM token of this vision block.
                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()

                if skip_vision_start_token[image_index - 1]:
                    # The first skip block starts at the absolute position fix_point
                    # ((fix_point - st_idx) + st_idx). fix_point is then 0, so later skip
                    # blocks start at st_idx, the same start as the text positions just
                    # appended. NOTE(review): this overlap looks possibly unintended. Also,
                    # fix_point is not reset per batch row, so only the first skip
                    # block in the whole batch is pinned. Confirm both.
                    if fix_point > 0:
                        fix_point = fix_point - st_idx
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx)
                    fix_point = 0
                else:
                    # Regular placement: the vision grid starts right after the preceding text.
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                # Skip past the placeholder tokens; assumes the block has exactly
                # llm_grid_t * llm_grid_h * llm_grid_w of them.
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            # Trailing text after the last vision block.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            # Delta = max position + 1 - padded row length.
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas
    else:
        # Text-only path: identical positions on all three axes.
        if attention_mask is not None:
            # Positions count masked-in tokens (starting at 0); padded slots are set to 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )
        return position_ids, mrope_position_deltas
|
|