| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Data helpers used by inference (`inference_lance.py`, `ValidationDataset`) and the |
| Lance model core (`modeling/lance/lance.py`). |
| |
| Exported utilities: |
| - Position id helpers (image / video, interpolate / extrapolate) |
| - Patchify helpers (image + video-with-merge) |
| - create_sparse_mask : flex-attention sparse mask builder |
| - add_special_tokens : register chat / vision tokens on a tokenizer |
| - len2weight : CE loss reweighting factor |
| """ |
|
|
| from einops import rearrange |
|
|
| import torch |
| from torch.nn.attention.flex_attention import or_masks, and_masks |
|
|
|
|
| |
| |
| |
|
|
def get_flattened_position_ids_interpolate_video(num_frames, img_h, img_w, patch_size, max_num_frames, max_num_patches_per_side):
    """Return flattened (t, h, w) position ids, interpolating each axis onto a fixed grid.

    Every axis of the actual patch grid is rescaled to the unit interval and
    bucketized into ``max_*`` equal bins, so grids of any size are stretched
    (or squeezed) onto a ``max_num_frames x S x S`` table, where
    ``S = max_num_patches_per_side``. Ids are laid out row-major with stride
    ``S`` per spatial row and ``S * S`` per frame.

    Args:
        num_frames: number of temporal positions in the input.
        img_h, img_w: spatial size; the patch grid is ``img_h // patch_size`` by
            ``img_w // patch_size``.
        patch_size: spatial patch size.
        max_num_frames: number of temporal bins of the target grid.
        max_num_patches_per_side: number of spatial bins per side of the target grid.

    Returns:
        1-D integer tensor with one id per (frame, row, col) patch.
    """

    def axis_buckets(count, max_count):
        # Bin edges 1/max, 2/max, ... strictly below 1.0.
        edges = torch.arange(1 / max_count, 1.0, 1 / max_count)
        # Fractional start of each of the `count` cells along this axis.
        starts = torch.arange(0, 1 - 1e-6, 1 / count)
        return torch.bucketize(starts, edges, right=True)

    side = max_num_patches_per_side
    frame_ids = axis_buckets(num_frames, max_num_frames)
    row_ids = axis_buckets(img_h // patch_size, side)
    col_ids = axis_buckets(img_w // patch_size, side)
    grid = (
        frame_ids.view(-1, 1, 1) * side * side
        + row_ids.view(1, -1, 1) * side
        + col_ids.view(1, 1, -1)
    )
    return grid.flatten()
|
|
|
|
def get_flattened_position_ids_extrapolate_video(t, h, w, max_latent_size):
    """Return flattened (t, h, w) position ids on a fixed grid without rescaling.

    Cell ``(i, j, k)`` maps to ``i * M * M + j * M + k`` with
    ``M = max_latent_size``; ids are emitted in row-major (t, h, w) order.

    The original note lists typical defaults of num_frames = 7 (corresponding to
    25 frames) and max_num_patches_per_side = 64. These presumably correspond to
    ``t`` and ``max_latent_size`` here — confirm with callers.
    """
    frame_stride = max_latent_size * max_latent_size
    row_stride = max_latent_size
    frames = torch.arange(t)[:, None, None] * frame_stride
    rows = torch.arange(h)[None, :, None] * row_stride
    cols = torch.arange(w)[None, None, :]
    return (frames + rows + cols).flatten()
|
|
|
|
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return flattened row-major position ids for an image patch grid, without rescaling.

    The ``(img_h // patch_size) x (img_w // patch_size)`` grid is placed in the
    top-left corner of a grid with ``max_num_patches_per_side`` columns, so
    patch ``(r, c)`` gets id ``r * max_num_patches_per_side + c``.
    """
    rows = torch.arange(img_h // patch_size).unsqueeze(1)
    cols = torch.arange(img_w // patch_size)
    return (rows * max_num_patches_per_side + cols).flatten()
|
|
|
|
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return flattened position ids, interpolating the patch grid onto a fixed square grid.

    Each row/column index of the ``(img_h // patch_size) x (img_w // patch_size)``
    grid is rescaled to the unit interval and bucketized into
    ``max_num_patches_per_side`` equal bins. Patch ``(r, c)`` then gets id
    ``bin(r) * max_num_patches_per_side + bin(c)``.
    """
    side = max_num_patches_per_side
    step = 1 / side
    # Bin edges step, 2*step, ... strictly below 1.0 (shared by both axes).
    edges = torch.arange(step, 1.0, step)
    rows, cols = (
        torch.bucketize(torch.arange(0, 1 - 1e-6, 1 / count), edges, right=True)
        for count in (img_h // patch_size, img_w // patch_size)
    )
    return (rows.unsqueeze(1) * side + cols).flatten()
|
|
|
|
| |
| |
| |
|
|
def patchify(image, patch_size):
    """Split a ``(C, H, W)`` image into flattened, non-overlapping square patches.

    Patches are ordered row-major over the patch grid. Each patch is flattened
    as ``(p, q, C)``, i.e. channel-last within the patch.

    Returns:
        Tensor of shape ``((H // p) * (W // p), p * p * C)``.
    """
    channels, height, width = image.shape
    assert height % patch_size == 0 and width % patch_size == 0
    rows, cols = height // patch_size, width // patch_size
    grid = image.reshape(channels, rows, patch_size, cols, patch_size)
    # (C, rows, p, cols, q) -> (rows, cols, p, q, C)
    grid = grid.permute(1, 3, 2, 4, 0)
    return grid.reshape(-1, patch_size * patch_size * channels)
|
|
|
|
def patchify_video_with_merge(video, spatial_patch_size, temporal_patch_size, merge_size=2):
    """Patchify a video so each merge_size x merge_size block of patches is contiguous.

    Args:
        video: Tensor of shape ``[C, T, H, W]``.
        spatial_patch_size: patch size ``p`` along H and W.
        temporal_patch_size: patch size ``tp`` along T.
        merge_size: spatial merge factor ``ms`` (default 2). The patch grid
            ``H // p`` and ``W // p`` must each be divisible by it.

    Returns:
        Tensor of shape ``[(T // tp) * (H // p) * (W // p), C * tp * p * p]``.
        Rows are ordered by (temporal group, merged row, merged col, row inside
        the merge block, col inside the merge block). Each row's features are
        laid out as ``(C, tp, p, p)``.

    Raises:
        ValueError: if ``video`` is not 4-D, or its sizes are not divisible by
            the patch sizes / merge factor.
    """
    if video.dim() != 4:
        raise ValueError(f"expected a 4-D [C, T, H, W] video, got shape {tuple(video.shape)}")
    C, T, H, W = video.shape
    p, tp, ms = spatial_patch_size, temporal_patch_size, merge_size

    if T % tp or H % p or W % p:
        raise ValueError(
            f"video shape (T={T}, H={H}, W={W}) is not divisible by "
            f"temporal_patch_size={tp} / spatial_patch_size={p}"
        )
    gt, gh, gw = T // tp, H // p, W // p
    if gh % ms or gw % ms:
        raise ValueError(f"patch grid ({gh}, {gw}) is not divisible by merge_size={ms}")

    # Factor each axis of the [C, T, H, W] layout:
    # C, T -> (gt, tp), H -> (gh/ms, ms, p), W -> (gw/ms, ms, p).
    video = video.reshape(C, gt, tp, gh // ms, ms, p, gw // ms, ms, p)
    # -> (gt, gh/ms, gw/ms, ms_h, ms_w, C, tp, p_h, p_w): merge blocks become
    # contiguous rows and each row's features are (C, tp, p, p).
    video = video.permute(1, 3, 6, 4, 7, 0, 2, 5, 8)
    return video.reshape(gt * gh * gw, C * tp * p * p)
|
|
|
|
| |
| |
| |
|
|
| def create_sparse_mask(document_lens, split_lens, attn_modes, device): |
| def causal_mask(b, h, q_idx, kv_idx): |
| return q_idx >= kv_idx |
|
|
| def full_and_noise_mask(b, h, q_idx, kv_idx): |
| return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) |
|
|
| def remove_noise_mask(b, h, q_idx, kv_idx): |
| return ~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])) |
|
|
| def sample_mask(b, h, q_idx, kv_idx): |
| return document_id[q_idx] == document_id[kv_idx] |
|
|
| full_and_noise_tmp = [] |
| noise_tmp = [] |
|
|
| for i, (length, mode) in enumerate(zip(split_lens, attn_modes)): |
| value = i if mode in ["full", "noise"] else -1 |
| full_and_noise_tmp.extend([value] * length) |
| value_noise = i if mode == "noise" else -1 |
| noise_tmp.extend([value_noise] * length) |
|
|
| full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device) |
| noise_seq_id = torch.Tensor(noise_tmp).to(device) |
|
|
| document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device) |
|
|
| return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask) |
|
|
|
|
| |
| |
| |
|
|
def add_special_tokens(tokenizer):
    """Ensure the chat / vision marker tokens are present in ``tokenizer``.

    Tokens already listed in ``tokenizer.special_tokens_map`` (string values or
    list values) are skipped. The remaining ones are passed to
    ``tokenizer.add_tokens``.

    Returns:
        ``(tokenizer, new_token_ids, num_new_tokens)``. ``new_token_ids`` maps
        ``bos_token_id`` / ``eos_token_id`` / ``start_of_image`` / ``end_of_image``
        to the ids of ``<|im_start|>`` / ``<|im_end|>`` / ``<|vision_start|>`` /
        ``<|vision_end|>``. ``num_new_tokens`` is whatever ``add_tokens`` returned.
    """
    registered = []
    for entry in tokenizer.special_tokens_map.values():
        if isinstance(entry, str):
            registered.append(entry)
        elif isinstance(entry, list):
            registered.extend(entry)

    wanted = ("<|im_start|>", "<|im_end|>", "<|vision_start|>", "<|vision_end|>")
    missing = [token for token in wanted if token not in registered]
    num_new_tokens = tokenizer.add_tokens(missing)

    lookup = tokenizer.convert_tokens_to_ids
    new_token_ids = {
        "bos_token_id": lookup("<|im_start|>"),
        "eos_token_id": lookup("<|im_end|>"),
        "start_of_image": lookup("<|vision_start|>"),
        "end_of_image": lookup("<|vision_end|>"),
    }
    return tokenizer, new_token_ids, num_new_tokens
|
|
|
|
def len2weight(x, loss_reduction="square"):
    """Return the CE-loss weight for a segment of ``x`` loss tokens.

    * ``"token"``  -> 1 (every token counts equally)
    * ``"sample"`` -> 1 / x (every sample counts equally)
    * ``"square"`` -> 1 / sqrt(x) (in between; the default)

    ``x == 0`` short-circuits and returns ``x`` itself, before the mode is checked.

    Raises:
        NotImplementedError: for any other ``loss_reduction`` (when ``x != 0``).
    """
    if x == 0:
        return x
    weighting = (
        ("token", lambda n: 1),
        ("sample", lambda n: 1 / n),
        ("square", lambda n: 1 / (n**0.5)),
    )
    for name, weight_fn in weighting:
        if loss_reduction == name:
            return weight_fn(x)
    raise NotImplementedError(loss_reduction)
|
|