# Source: Lance/data/data_utils.py (uploaded by Nayefleb via huggingface_hub, commit 8b306b3)
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
"""
Data helpers used by inference (`inference_lance.py`, `ValidationDataset`) and the
Lance model core (`modeling/lance/lance.py`).
Exported utilities:
- Position id helpers (image / video, interpolate / extrapolate)
- Patchify helpers (image + video-with-merge)
- create_sparse_mask : flex-attention sparse mask builder
- add_special_tokens : register chat / vision tokens on a tokenizer
- len2weight : CE loss reweighting factor
"""
from einops import rearrange
import torch
from torch.nn.attention.flex_attention import or_masks, and_masks
# ------------------------------------------------------------------
# Position id helpers
# ------------------------------------------------------------------
def get_flattened_position_ids_interpolate_video(num_frames, img_h, img_w, patch_size, max_num_frames, max_num_patches_per_side):
    """Return flattened (frame, row, column) position ids for a video, resampled onto a fixed grid.

    Every axis is "interpolated": indices ``0 .. n-1`` are spread uniformly over
    ``[0, 1)`` and each is mapped to one of ``max_n`` equal-width bins via
    ``torch.bucketize(..., right=True)``. Inputs of any size therefore land in a
    position table of ``max_num_frames x S x S`` entries, where
    ``S = max_num_patches_per_side``.

    Args:
        num_frames: number of frames along the temporal axis.
        img_h: frame height; ``img_h // patch_size`` (floor) gives the patch rows.
        img_w: frame width; ``img_w // patch_size`` (floor) gives the patch columns.
        patch_size: spatial patch edge length.
        max_num_frames: number of temporal bins.
        max_num_patches_per_side: number of spatial bins per axis (``S``).

    Returns:
        1-D integer tensor with one id per (frame, patch row, patch column),
        ordered frame-major, then row, then column. Each id equals
        ``t_bin * S * S + h_bin * S + w_bin``.
    """

    def _uniform_buckets(count, num_buckets):
        # Spread `count` points uniformly over [0, 1) and return the index of
        # the equal-width bin (out of `num_buckets`) that each point falls in.
        # The 1e-6 margin keeps the arange from emitting a point at ~1.0.
        edges = torch.arange(1 / num_buckets, 1.0, 1 / num_buckets)
        points = torch.arange(0, 1 - 1e-6, 1 / count)
        return torch.bucketize(points, edges, right=True)

    side = max_num_patches_per_side
    t_bucket = _uniform_buckets(num_frames, max_num_frames)
    h_bucket = _uniform_buckets(img_h // patch_size, side)
    w_bucket = _uniform_buckets(img_w // patch_size, side)

    # Broadcast the three 1-D bucket vectors into a (T, H, W) grid of linear ids.
    grid = (
        t_bucket.view(-1, 1, 1) * side * side
        + h_bucket.view(1, -1, 1) * side
        + w_bucket.view(1, 1, -1)
    )
    return grid.flatten()
def get_flattened_position_ids_extrapolate_video(t, h, w, max_latent_size):
    """Return flattened (frame, row, column) position ids without resampling.

    Each index is used as-is ("extrapolated"). ``max_latent_size`` only sets the
    strides, so the id of element ``(ti, hi, wi)`` is
    ``ti * max_latent_size * max_latent_size + hi * max_latent_size + wi``.

    Args:
        t: number of frames.
        h: number of rows per frame.
        w: number of columns per row.
        max_latent_size: stride used for both the row and frame axes.

    Returns:
        1-D integer tensor of ``t * h * w`` ids, ordered frame-major, then row,
        then column.

    Note (translated from the original comment): by default num_frames = 7
    (corresponding to 25 frames) and max_num_patches_per_side = 64.
    """
    stride = max_latent_size
    frame_idx = torch.arange(0, t).view(-1, 1, 1)
    row_idx = torch.arange(0, h).view(1, -1, 1)
    col_idx = torch.arange(0, w).view(1, 1, -1)
    # Broadcasting produces the full (t, h, w) grid of linear ids.
    return (frame_idx * stride * stride + row_idx * stride + col_idx).flatten()
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return flattened (row, column) patch position ids without resampling.

    The patch grid is ``(img_h // patch_size) x (img_w // patch_size)`` (floor
    division). Patch ``(r, c)`` gets id ``r * max_num_patches_per_side + c``, so
    ``max_num_patches_per_side`` acts purely as the row stride.

    Returns:
        1-D integer tensor of ids in row-major order.
    """
    rows = torch.arange(0, img_h // patch_size)
    cols = torch.arange(0, img_w // patch_size)
    return (rows.unsqueeze(1) * max_num_patches_per_side + cols.unsqueeze(0)).reshape(-1)
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return flattened (row, column) patch position ids, resampled onto a fixed grid.

    Patch rows and columns (``img_h // patch_size`` and ``img_w // patch_size``)
    are spread uniformly over ``[0, 1)``. Each coordinate is mapped to one of
    ``max_num_patches_per_side`` equal-width bins with
    ``torch.bucketize(..., right=True)``. Patch ``(r, c)`` gets id
    ``r_bin * max_num_patches_per_side + c_bin``.

    Returns:
        1-D integer tensor of ids in row-major order.
    """
    rows = img_h // patch_size
    cols = img_w // patch_size
    step = 1 / max_num_patches_per_side
    edges = torch.arange(step, 1.0, step)
    # The 1e-6 margin keeps the arange from emitting a point at ~1.0.
    row_bucket = torch.bucketize(torch.arange(0, 1 - 1e-6, 1 / rows), edges, right=True)
    col_bucket = torch.bucketize(torch.arange(0, 1 - 1e-6, 1 / cols), edges, right=True)
    return (row_bucket.unsqueeze(1) * max_num_patches_per_side + col_bucket.unsqueeze(0)).reshape(-1)
# ------------------------------------------------------------------
# Patchify helpers
# ------------------------------------------------------------------
def patchify(image, patch_size):
    """Cut a ``[C, H, W]`` image into flattened, non-overlapping square patches.

    Both ``H`` and ``W`` must be divisible by ``patch_size`` (checked with
    ``assert``).

    Returns:
        Tensor of shape ``[(H // p) * (W // p), p * p * C]``. Patches are in
        row-major order, and each patch is laid out as (pixel row, pixel column,
        channel).
    """
    channels, height, width = image.shape
    assert height % patch_size == 0 and width % patch_size == 0
    rows, cols = height // patch_size, width // patch_size
    grid = image.reshape(channels, rows, patch_size, cols, patch_size)
    # (c, row, py, col, px) -> (row, col, py, px, c)
    grid = grid.permute(1, 3, 2, 4, 0)
    return grid.reshape(-1, patch_size**2 * channels)
def patchify_video_with_merge(video, spatial_patch_size, temporal_patch_size, merge_size=2):
    """Flatten a video into spatio-temporal patches, grouped for spatial merging.

    Args:
        video: tensor of shape ``[C, T, H, W]``.
        spatial_patch_size: patch edge length ``p`` along H and W.
        temporal_patch_size: number of frames ``tp`` per patch along T.
        merge_size: spatial merge factor ``ms`` (documented as fixed to 2). Each
            run of ``ms * ms`` consecutive output rows covers one
            ``ms x ms`` block of neighbouring patches.

    Returns:
        Tensor of shape ``[(T // tp) * (H // p) * (W // p), C * tp * p * p]``.
        Rows are ordered by (frame group, merge-block row, merge-block column,
        row inside block, column inside block). Each row is laid out as
        (channel, frame in group, pixel row, pixel column).

    ``T`` must be divisible by ``tp``, and ``H // p`` and ``W // p`` by ``ms``;
    otherwise the reshape fails.
    """
    C, T, H, W = video.shape
    p, tp, ms = spatial_patch_size, temporal_patch_size, merge_size
    grid_t, grid_h, grid_w = T // tp, H // p, W // p
    # Split every axis into (block, within-block) factors directly on the
    # [C, T, H, W] layout:
    #   0:C 1:grid_t 2:tp 3:blocks_h 4:ms 5:p 6:blocks_w 7:ms 8:p
    blocks = video.reshape(C, grid_t, tp, grid_h // ms, ms, p, grid_w // ms, ms, p)
    # -> (grid_t, blocks_h, blocks_w, ms_h, ms_w, C, tp, p, p)
    blocks = blocks.permute(1, 3, 6, 4, 7, 0, 2, 5, 8)
    return blocks.reshape(grid_t * grid_h * grid_w, C * tp * p * p)
# ------------------------------------------------------------------
# Sparse attention mask (flex-attention)
# ------------------------------------------------------------------
def create_sparse_mask(document_lens, split_lens, attn_modes, device):
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def full_and_noise_mask(b, h, q_idx, kv_idx):
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
def remove_noise_mask(b, h, q_idx, kv_idx):
return ~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx]))
def sample_mask(b, h, q_idx, kv_idx):
return document_id[q_idx] == document_id[kv_idx]
full_and_noise_tmp = []
noise_tmp = []
for i, (length, mode) in enumerate(zip(split_lens, attn_modes)):
value = i if mode in ["full", "noise"] else -1
full_and_noise_tmp.extend([value] * length)
value_noise = i if mode == "noise" else -1
noise_tmp.extend([value_noise] * length)
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
noise_seq_id = torch.Tensor(noise_tmp).to(device)
document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
# ------------------------------------------------------------------
# Tokenizer / loss helpers
# ------------------------------------------------------------------
def add_special_tokens(tokenizer):
    """Register the chat and vision marker tokens on ``tokenizer``.

    Any of ``<|im_start|>``, ``<|im_end|>``, ``<|vision_start|>`` and
    ``<|vision_end|>`` that is not already listed in
    ``tokenizer.special_tokens_map`` is passed to ``tokenizer.add_tokens``.

    Returns:
        A tuple ``(tokenizer, new_token_ids, num_new_tokens)``:

        * ``tokenizer`` is the same object that was passed in.
        * ``new_token_ids`` maps ``bos_token_id`` / ``eos_token_id`` /
          ``start_of_image`` / ``end_of_image`` to the ids of ``<|im_start|>`` /
          ``<|im_end|>`` / ``<|vision_start|>`` / ``<|vision_end|>``.
        * ``num_new_tokens`` is whatever ``add_tokens`` returned.
    """
    registered = []
    for entry in tokenizer.special_tokens_map.values():
        # Entries are either a single token string or a list of token strings;
        # anything else (e.g. None) is ignored.
        if isinstance(entry, str):
            registered.append(entry)
        elif isinstance(entry, list):
            registered.extend(entry)

    markers = ("<|im_start|>", "<|im_end|>", "<|vision_start|>", "<|vision_end|>")
    missing = [marker for marker in markers if marker not in registered]
    num_new_tokens = tokenizer.add_tokens(missing)

    # Ids are looked up only after add_tokens so newly added markers resolve.
    to_id = tokenizer.convert_tokens_to_ids
    new_token_ids = {
        "bos_token_id": to_id("<|im_start|>"),
        "eos_token_id": to_id("<|im_end|>"),
        "start_of_image": to_id("<|vision_start|>"),
        "end_of_image": to_id("<|vision_end|>"),
    }
    return tokenizer, new_token_ids, num_new_tokens
def len2weight(x, loss_reduction="square"):
    """Return the loss reweighting factor for a length ``x``.

    ``loss_reduction`` selects the scheme:

    * ``"token"``  -> ``1``
    * ``"sample"`` -> ``1 / x``
    * ``"square"`` -> ``1 / sqrt(x)`` (the default)

    ``x == 0`` short-circuits and returns ``x`` itself, before the scheme name is
    checked. An unknown scheme name raises ``NotImplementedError(loss_reduction)``.
    """
    if x == 0:
        return x
    # Checked in order with `==`, mirroring a plain if-chain.
    schemes = (
        ("token", lambda n: 1),
        ("sample", lambda n: 1 / n),
        ("square", lambda n: 1 / (n**0.5)),
    )
    for name, weight_of in schemes:
        if loss_reduction == name:
            return weight_of(x)
    raise NotImplementedError(loss_reduction)