# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # coding: utf-8 """ Data helpers used by inference (`inference_lance.py`, `ValidationDataset`) and the Lance model core (`modeling/lance/lance.py`). Exported utilities: - Position id helpers (image / video, interpolate / extrapolate) - Patchify helpers (image + video-with-merge) - create_sparse_mask : flex-attention sparse mask builder - add_special_tokens : register chat / vision tokens on a tokenizer - len2weight : CE loss reweighting factor """ from einops import rearrange import torch from torch.nn.attention.flex_attention import or_masks, and_masks # ------------------------------------------------------------------ # Position id helpers # ------------------------------------------------------------------ def get_flattened_position_ids_interpolate_video(num_frames, img_h, img_w, patch_size, max_num_frames, max_num_patches_per_side): num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size # temporal boundaries_t = torch.arange(1 / max_num_frames, 1.0, 1 / max_num_frames) fractional_coords_t = torch.arange(0, 1 - 1e-6, 1 / num_frames) bucket_coords_t = torch.bucketize(fractional_coords_t, boundaries_t, right=True) # spatial boundaries_s = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side) fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries_s, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries_s, right=True) pos_ids = ( bucket_coords_t[:, None, None] * max_num_patches_per_side * max_num_patches_per_side + bucket_coords_h[None, :, None] * max_num_patches_per_side + bucket_coords_w[None, None, :] ).flatten() return pos_ids def get_flattened_position_ids_extrapolate_video(t, h, w, max_latent_size): """ 默认情况下: num_frames = 7 (对应 25 frames) max_num_patches_per_side = 64 """ coords_t = torch.arange(0, t) coords_h = torch.arange(0, h) coords_w = torch.arange(0, w) pos_ids = ( coords_t[:, None, None] * max_latent_size * max_latent_size + coords_h[None, :, None] * max_latent_size + coords_w[None, None, :] ).flatten() return pos_ids def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side): num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size coords_h = torch.arange(0, num_patches_h) coords_w = torch.arange(0, num_patches_w) pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() return pos_ids def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side): num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side) fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten() return pos_ids # ------------------------------------------------------------------ # Patchify helpers # ------------------------------------------------------------------ def patchify(image, patch_size): p = patch_size c, h, w = image.shape assert h % p == 0 and w % p == 0 image = image.reshape(c, h // p, p, w // p, p) image = torch.einsum("chpwq->hwpqc", image) image = image.reshape(-1, p**2 * c) return image def patchify_video_with_merge(video, spatial_patch_size, temporal_patch_size, merge_size=2): """ Args: video: Tensor of shape [C, T, H, W] spatial_patch_size: patch size for H/W temporal_patch_size: patch size for T merge_size: merging factor for spatial grid (固定为 2) Returns: patches: Tensor of shape [num_patches, patch_dim] """ video = rearrange(video, "C T H W -> T C H W") T, C, H, W = video.shape p, tp, ms = spatial_patch_size, temporal_patch_size, merge_size gt, gh, gw = T // tp, H // p, W // p video = video.reshape(gt, tp, C, gh // ms, ms, p, gw // ms, ms, p) video = video.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) patches = video.reshape(gt * gh * gw, C * tp * p * p) return patches # ------------------------------------------------------------------ # Sparse attention mask (flex-attention) # ------------------------------------------------------------------ def create_sparse_mask(document_lens, split_lens, attn_modes, device): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx def full_and_noise_mask(b, h, q_idx, kv_idx): return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) def remove_noise_mask(b, h, q_idx, kv_idx): return ~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])) def sample_mask(b, h, q_idx, kv_idx): return document_id[q_idx] == document_id[kv_idx] full_and_noise_tmp = [] noise_tmp = [] for i, (length, mode) in enumerate(zip(split_lens, attn_modes)): value = i if mode in ["full", "noise"] else -1 full_and_noise_tmp.extend([value] * length) value_noise = i if mode == "noise" else -1 noise_tmp.extend([value_noise] * length) full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device) noise_seq_id = torch.Tensor(noise_tmp).to(device) document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device) return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask) # ------------------------------------------------------------------ # Tokenizer / loss helpers # ------------------------------------------------------------------ def add_special_tokens(tokenizer): all_special_tokens = [] for k, v in tokenizer.special_tokens_map.items(): if isinstance(v, str): all_special_tokens.append(v) elif isinstance(v, list): all_special_tokens += v new_tokens = [] for tok in ("<|im_start|>", "<|im_end|>", "<|vision_start|>", "<|vision_end|>"): if tok not in all_special_tokens: new_tokens.append(tok) num_new_tokens = tokenizer.add_tokens(new_tokens) new_token_ids = dict( bos_token_id=tokenizer.convert_tokens_to_ids("<|im_start|>"), eos_token_id=tokenizer.convert_tokens_to_ids("<|im_end|>"), start_of_image=tokenizer.convert_tokens_to_ids("<|vision_start|>"), end_of_image=tokenizer.convert_tokens_to_ids("<|vision_end|>"), ) return tokenizer, new_token_ids, num_new_tokens def len2weight(x, loss_reduction="square"): if x == 0: return x if loss_reduction == "token": return 1 if loss_reduction == "sample": return 1 / x if loss_reduction == "square": return 1 / (x**0.5) raise NotImplementedError(loss_reduction)