Lance / data /video /sampler /frames.py
Nayefleb's picture
Upload folder using huggingface_hub
8b306b3 verified
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
"""
Frame samplers.
TODO: We may need to write a frame sampler that supports custom requirements.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
import numpy as np
class FrameSamplerOutput(NamedTuple):
    """
    Result returned by a frame sampler.

    Carries the indices to use for frame decoding and, optionally,
    additional information to hand back to the user.
    """

    # Frame indices to decode.
    indices: List[int]
    # Optional extra information produced by the sampler.
    # NOTE(review): this default dict is a single object shared by every
    # instance created without `additional_info`; treat it as read-only.
    additional_info: Dict[str, Any] = {}
class FrameSampler(ABC):
    """
    Frame sampler base class.

    Child classes must implement ``__call__`` to return the decoding indices,
    or raise if the video cannot be sampled (e.g. it is too short).
    """

    @abstractmethod
    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        """Return the sampling result; subclasses must override this method."""
        raise NotImplementedError
class AllFrameSampler(FrameSampler):
    """
    Sampler that keeps every frame of a video.
    """

    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        """Return the indices ``0 .. num_frames - 1`` in ascending order."""
        every_index = [*range(num_frames)]
        return FrameSamplerOutput(indices=every_index)
class OnlyFirstFrameSampler:
    """
    Sampler that keeps only the first frame of a video.
    """

    def __call__(self, frames_info: Dict[str, int], **kwargs) -> FrameSamplerOutput:
        """Return index 0 only; ``frames_info`` and extra keywords are ignored."""
        first_only = [0]
        return FrameSamplerOutput(indices=first_only)
class FixedFrameSampler:
    """
    Fixed-count frame sampler (one formula for both up- and down-sampling).

    - Takes a ``frames_info`` dict holding ``start_frame``, ``end_frame`` and
      ``total_frames``.
    - For any ``total_frames >= 1`` it always produces exactly ``num_frames``
      frame indices.
    - The first index is ``start_frame`` and the last is
      ``start_frame + total_frames - 1``; every index is capped at
      ``end_frame - 1``. Indices in between are evenly spaced.
    - When ``total_frames < num_frames`` indices repeat,
      e.g. [0,1,2] -> [0,0,1,1,2,2].
    """

    def __init__(self, num_frames: int):
        """
        Args:
            num_frames: number of indices to return on every call.

        Raises:
            ValueError: if ``num_frames`` is smaller than 1.
        """
        if num_frames < 1:
            raise ValueError("num_frames must be ≥ 1")
        self.num_frames = num_frames

    def __call__(self, frames_info: Dict[str, int]) -> "FrameSamplerOutput":
        """
        Sample ``num_frames`` global frame indices.

        Args:
            frames_info: dict containing 'start_frame', 'end_frame' and
                'total_frames'.

        Returns:
            FrameSamplerOutput whose ``indices`` always has length
            ``num_frames`` and whose ``additional_info`` echoes the three
            input values.

        Raises:
            ValueError: if a required key is missing (or None), or if
                ``total_frames`` is smaller than 1.
        """
        start = frames_info.get('start_frame')
        total = frames_info.get('total_frames')
        end = frames_info.get('end_frame')
        if start is None or total is None or end is None:
            raise ValueError("frames_info must contain 'start_frame', 'end_frame', and 'total_frames'")
        if total < 1:
            raise ValueError("total_frames must be ≥ 1")

        # Offsets relative to the start of the segment.
        rel_indices = self._get_indices(total)
        # Shift to global frame numbers and cap at the last frame before end.
        indices = [min(start + idx, end - 1) for idx in rel_indices]
        return FrameSamplerOutput(
            indices=indices,
            additional_info={
                "start_frame": start,
                "end_frame": end,
                "total_frames": total,
            },
        )

    def _get_indices(self, total: int) -> List[int]:
        """Return ``num_frames`` evenly spaced offsets within ``[0, total - 1]``."""
        # A single frame has no spacing to compute (and would divide by zero).
        if self.num_frames == 1:
            return [0]
        # One formula covers both up-sampling and down-sampling.
        return [
            int(round(i * (total - 1) / (self.num_frames - 1)))
            for i in range(self.num_frames)
        ]
class ConsecutiveFrameSampler(FrameSampler):
    """
    Consecutive (strided) frame sampler.

    Selects one evenly strided run of frames inside
    ``[start_frame, end_frame)``. The run length is the largest value of the
    form ``temporal * n + 1`` that does not exceed
    ``(end_frame - start_frame) // stride``.

    Arguments:
        strides: frame skip.
            For example, 1 denotes no skip. 2 denotes select every other frame. 3
            denotes select every third frame. When a list is given, the stride is
            randomly chosen with even probability. However, user may set it to
            [1,1,2] to denote 1 with 66% probability and 2 with 33% probability.
        temporal: the ``k`` in the ``k * n + 1`` run length.
        clip: clip location.
            "head": start at ``start_frame``.
            "tail": start at the last position where the run still fits.
            "center": clip video at the center.
            "uniform": clip video uniformly at random.
        jitter: jitter to the location.
            Only applicable when clip is "center".
            The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        strides: Union[int, List[int]] = 1,
        temporal: int = 4,
        clip: Literal["head", "tail", "center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        strides = [strides] if isinstance(strides, int) else strides
        assert len(strides) > 0
        self.strides = np.array(strides)
        self.temporal = temporal
        self.clip = clip
        self.jitter = jitter

    @staticmethod
    def _max_kn_plus_1(length, k):
        """Return the largest integer of the form ``k * n + 1`` not exceeding ``length``.

        Raises:
            ValueError: if ``length`` is smaller than 1.
        """
        if length < 1:
            raise ValueError("Length must be at least 1.")
        n = (length - 1) // k
        return k * n + 1

    def __call__(self, frames_info: Dict[str, int]) -> FrameSamplerOutput:
        """
        Sample a strided run of frames.

        Args:
            frames_info: dict with 'start_frame', 'end_frame' and 'total_frames'.

        Returns:
            FrameSamplerOutput with the chosen indices; ``additional_info``
            reports the stride, the run's start, ``start + length * stride``
            as its end, and the input ``total_frames``.

        Raises:
            ValueError: if fewer than one frame fits at the chosen stride.
            NotImplementedError: if ``clip`` is not a supported mode.
        """
        start_frame = frames_info["start_frame"]
        end_frame = frames_info["end_frame"]
        num_frames = frames_info["total_frames"]

        stride = np.random.choice(self.strides)
        frames = end_frame - start_frame
        # Frames available at this stride, then snapped down to k * n + 1.
        length = self._max_kn_plus_1(frames // stride, self.temporal)

        # Choose start index.
        min_start_index = start_frame
        max_start_index = end_frame - 1 - stride * (length - 1)
        mid_start_index = round((min_start_index + max_start_index) / 2)
        # Drawn unconditionally so random-state consumption does not depend on clip.
        jitter = round(np.random.normal(loc=0, scale=self.jitter))

        if self.clip == "head":
            start_index = min_start_index
        elif self.clip == "tail":
            start_index = max_start_index
        elif self.clip == "center":
            start_index = mid_start_index + jitter
        elif self.clip == "uniform":
            start_index = np.random.randint(min_start_index, max_start_index + 1)
        else:
            raise NotImplementedError
        # Jitter may push the start outside the feasible window.
        start_index = np.clip(start_index, min_start_index, max_start_index)

        # Compute indices.
        indices = np.arange(start_index, start_index + length * stride, stride)

        # Return indices and additional information to return to user.
        return FrameSamplerOutput(
            indices=indices.tolist(),
            additional_info={
                "stride": stride,
                "start_frame": start_index,
                "end_frame": start_index + length * stride,
                "total_frames": num_frames,
            },
        )
class AdaptiveFrameSampler(FrameSampler):
    """
    Adaptive frame sampler.

    Arguments:
        lengths: frame length(s) to return.
            For example, [5,10] means the sampler returns either 5 or 10 frames,
            picking the longest length that fits the video at the chosen stride
            (a 9-frame video at stride 1 is clipped to 5 frames).
        strides: frame skip.
            1 means no skip, 2 means every other frame, 3 means every third
            frame. With a list, one entry is drawn with even probability, so
            [1,1,2] gives stride 1 about 66% of the time and stride 2 about 33%.
        clip: clip location.
            "center": clip video at the center.
            "uniform": clip video uniformly at random.
        jitter: jitter to the location.
            Only applicable when clip is "center".
            The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        lengths: Union[int, List[int]],
        strides: Union[int, List[int]] = 1,
        clip: Literal["center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        # Accept a bare int as a one-element list.
        if isinstance(lengths, int):
            lengths = [lengths]
        if isinstance(strides, int):
            strides = [strides]
        assert len(lengths) > 0
        assert len(strides) > 0
        assert clip in ["center", "uniform"]
        assert jitter >= 0
        self.lengths = np.array(lengths)
        self.strides = np.array(strides)
        self.clip = clip
        self.jitter = jitter

    def __call__(
        self,
        num_frames: int,
    ) -> FrameSamplerOutput:
        """
        Sample a strided run of frames from a video of ``num_frames`` frames.

        Raises:
            ValueError: if no configured stride can fit any configured length.
        """
        # A stride is usable when at least one configured length fits under it.
        fits = self.lengths.reshape(-1, 1) <= num_frames // self.strides
        usable_strides = self.strides[fits.any(axis=0)]
        if usable_strides.size <= 0:
            raise ValueError(f"Video is too short ({num_frames} frames).")
        stride = np.random.choice(usable_strides)

        # Longest configured length that fits under the drawn stride.
        length = np.max(self.lengths[self.lengths <= num_frames // stride])

        # Feasible window for the first sampled frame.
        first_allowed = 0
        last_allowed = num_frames - 1 - stride * (length - 1)
        window_center = round((first_allowed + last_allowed) / 2)
        jitter = round(np.random.normal(loc=0, scale=self.jitter))

        if self.clip == "center":
            start_index = window_center + jitter
        elif self.clip == "uniform":
            start_index = np.random.randint(first_allowed, last_allowed + 1)
        else:
            raise NotImplementedError
        start_index = np.clip(start_index, first_allowed, last_allowed)

        stop_index = start_index + length * stride
        indices = np.arange(start_index, stop_index, stride)

        # Indices plus the parameters that produced them.
        return FrameSamplerOutput(
            indices=indices.tolist(),
            additional_info={
                "stride": stride,
                "start_frame": start_index,
                "end_frame": start_index + length * stride,
                "total_frames": num_frames,
            },
        )
@dataclass
class AdaptiveAdvancedFrameSamplerStrategy:
    """One sampling strategy: a stride together with the frame lengths allowed at it."""

    # Frame skip used by this strategy.
    stride: int
    # Weight of choosing this stride (int or float, non-negative).
    stride_prob: float
    # Candidate numbers of frames to sample at this stride.
    frame_lengths: List[int]
    # Either "uniform", "harmonic", or one explicit weight per frame length.
    frame_lengths_prob: Union[Literal["uniform", "harmonic"], List[float]]
class AdaptiveAdvancedFrameSampler(FrameSampler):
    """
    Advanced adaptive frame sampler supports different frame lengths for different strides,
    and supports probabilistic sampling of both the stride and the frame length.

    Arguments:
        strategies: A list of strategies to sample from. Strides must be unique.
        clip: clip location.
            "center": clip video at the center.
            "uniform": clip video uniformly at random.
        jitter: jitter to the location.
            Only applicable when clip is "center".
            The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        strategies: List[AdaptiveAdvancedFrameSamplerStrategy],
        clip: Literal["center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        assert len(strategies) > 0, "Strategies must not be empty"
        assert len({s.stride for s in strategies}) == len(strategies), "Strides cannot duplicate."
        assert clip in ["center", "uniform"]
        assert jitter >= 0
        self.clip = clip
        self.jitter = jitter
        self.strides = []
        self.strides_prob = []
        self.frame_lengths = []
        self.frame_lengths_prob = []

        for strategy in sorted(strategies, key=lambda s: s.stride):
            # Validate strides.
            assert isinstance(strategy.stride, int), "Stride must be an integer."
            assert strategy.stride > 0, "Stride must be a positive integer."
            self.strides.append(strategy.stride)

            # Assign strides_prob.
            assert isinstance(strategy.stride_prob, (int, float)), "Stride prob is not int/float."
            assert strategy.stride_prob >= 0, "Stride prob must be non-negative."
            self.strides_prob.append(strategy.stride_prob)

            # Assign frame lengths, sort by value.
            assert len(strategy.frame_lengths) > 0, "Frame lengths must not be empty."
            frame_lengths = np.array(strategy.frame_lengths)
            assert np.issubdtype(frame_lengths.dtype, np.integer), "Frame lengths must be integers."
            assert np.all(frame_lengths > 0), "Frame lengths must be positive integers."
            frame_lengths_sorted_idx = np.argsort(frame_lengths)
            frame_lengths = frame_lengths[frame_lengths_sorted_idx]
            self.frame_lengths.append(frame_lengths)

            # Assign frame lengths prob, apply the sorting to prob as well.
            if strategy.frame_lengths_prob == "uniform":
                # e.g. [0.2, 0.2, 0.2, 0.2, 0.2]
                frame_lengths_prob = np.full(len(frame_lengths), 1.0 / len(frame_lengths))
            elif strategy.frame_lengths_prob == "harmonic":
                # e.g. [0.2, 0.25, 0.33, 0.5, 1]
                frame_lengths_prob = np.flip(1 / np.arange(1, len(frame_lengths) + 1))
            elif isinstance(strategy.frame_lengths_prob, list):
                # Force float: integer weights such as [1, 2] would otherwise
                # yield an int array and the in-place division below would fail.
                frame_lengths_prob = np.array(strategy.frame_lengths_prob, dtype=float)
                # Check before reordering: fancy indexing would silently drop
                # surplus entries and hide a length mismatch.
                assert len(frame_lengths_prob) == len(frame_lengths), "Frame lengths prob mismatch."
                frame_lengths_prob = frame_lengths_prob[frame_lengths_sorted_idx]
            else:
                raise NotImplementedError
            assert len(frame_lengths_prob) == len(frame_lengths), "Frame lengths prob mismatch."
            assert np.all(frame_lengths_prob >= 0), "Frame lengths prob must not be negative."
            assert frame_lengths_prob.sum() > 0, "Frame lengths prob must not be all zeros."
            frame_lengths_prob /= frame_lengths_prob.sum()
            self.frame_lengths_prob.append(frame_lengths_prob)

        self.strides = np.array(self.strides)
        # Force float so all-integer stride weights can be normalised in place.
        self.strides_prob = np.array(self.strides_prob, dtype=float)
        assert self.strides_prob.sum() > 0, "Strides prob must not be all zeros."
        self.strides_prob /= self.strides_prob.sum()

    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        """
        Sample a strided run of frames from a video of ``num_frames`` frames.

        The stride and frame length are drawn by
        ``adptive_sample_framelen_and_stride``; the start position is then
        chosen according to ``clip``.
        """
        sample_result = adptive_sample_framelen_and_stride(
            num_frames=num_frames,
            strides=self.strides,
            strides_prob=self.strides_prob,
            frame_lengths=self.frame_lengths,
            frame_lengths_prob=self.frame_lengths_prob,
        )
        stride = sample_result["stride"]
        length = sample_result["frame_length"]

        # Choose start index.
        min_start_index = 0
        max_start_index = num_frames - 1 - stride * (length - 1)
        mid_start_index = round((min_start_index + max_start_index) / 2)
        jitter = round(np.random.normal(loc=0, scale=self.jitter))

        if self.clip == "center":
            start_index = mid_start_index + jitter
        elif self.clip == "uniform":
            start_index = np.random.randint(min_start_index, max_start_index + 1)
        else:
            raise NotImplementedError
        # Jitter may push the start outside the feasible window.
        start_index = np.clip(start_index, min_start_index, max_start_index)

        # Compute indices.
        indices = np.arange(start_index, start_index + length * stride, stride)

        # Return indices and additional information to return to user.
        return FrameSamplerOutput(
            indices=indices.tolist(),
            additional_info={
                "stride": stride,
                "start_frame": start_index,
                "end_frame": start_index + length * stride,
                "total_frames": num_frames,
            },
        )
class MultiClipsFrameSampler(FrameSampler):
    """
    Frame sampler for a video described as several clips.

    Arguments:
        temporal: temporal downsample factor.
        sample_fps: fps of the sampled frames.
        truncate: whether to truncate the clips by the max duration
            (default False; the original note says clip_indices are
            already truncated).
        max_duration: cap on the total duration, in seconds.
        length_type: "kn+1" adds one extra frame (assigned to the first clip);
            "kn" does not.
        assert_seconds: when True the total duration is rounded to whole
            seconds before the frame count is derived.
    """

    def __init__(
        self,
        temporal: int = 4,
        sample_fps: int = 12,
        truncate: bool = False,
        max_duration: int = 12,
        length_type: Literal["kn", "kn+1"] = "kn+1",
        assert_seconds: bool = True,
    ):
        self.temporal = temporal
        self.sample_fps = sample_fps
        self.truncate = truncate
        self.max_duration = max_duration
        self.length_type = length_type
        self.assert_seconds = assert_seconds

    def __call__(self, frames_info: Dict[str, int]) -> FrameSamplerOutput:
        """
        Sample frames across all clips.

        Reads ``frames_info["clip_indices"]`` (iterated as (start, end) pairs)
        and ``frames_info["fps"]``. Returns the sampled indices plus the
        per-clip frame counts and per-clip latent frame counts.
        """
        clip_indices = frames_info["clip_indices"]
        origin_fps = frames_info["fps"]
        if self.truncate:
            clip_indices = self.truncate_to_bucket(clip_indices, origin_fps)

        plus_one = self.length_type == "kn+1"
        total_seconds = sum([(end - start) / origin_fps for start, end in clip_indices])

        if self.assert_seconds:
            whole_seconds = int(round(total_seconds))
            # Even without truncating the clips, cap the total duration.
            if not self.truncate:
                whole_seconds = min(whole_seconds, self.max_duration)
            duration = int(round(whole_seconds))
            n_frames = duration * self.sample_fps
            if plus_one:
                n_frames += 1
        else:
            duration = total_seconds
            # Even without truncating the clips, cap the total duration.
            if not self.truncate:
                duration = min(duration, self.max_duration)
            n_frames = int(round(duration * self.sample_fps))
            if plus_one:
                floored = n_frames // self.temporal * self.temporal
                # An exact multiple drops one full temporal group before the +1.
                n_frames = floored + 1 if n_frames % self.temporal != 0 else floored + 1 - self.temporal

        clip_n_frames = self.split_n_frames_by_clip(n_frames, clip_indices)
        sample_indices = self.sample_frame_indices(clip_indices, clip_n_frames)
        clip_n_latent_frames = [
            (count + self.temporal - 1) // self.temporal for count in clip_n_frames
        ]
        extra = {
            "clip_n_frames": clip_n_frames,
            "clip_n_latent_frames": clip_n_latent_frames,
        }
        return FrameSamplerOutput(indices=sample_indices, additional_info=extra)

    def truncate_to_bucket(self, clip_indices, fps):
        """
        Shorten the first or last clip so the total duration drops to
        ``min(int(total duration), max_duration)`` seconds.

        Returns a new list of (start, end) tuples.
        """
        clips = [tuple(pair) for pair in clip_indices]
        durations = [(end - start) / fps for start, end in clips]
        duration = sum(durations)
        max_duration = min(int(duration), self.max_duration)
        cutoff = duration - max_duration
        if not cutoff > 0:
            return clips

        tail_left = durations[-1] - cutoff
        head_left = durations[0] - cutoff
        if tail_left > head_left:
            # Trim the end of the last clip.
            start, end = clips[-1]
            end = min(round(tail_left * fps), end) + start
            clips[-1] = (start, end)
        else:
            # Trim the beginning of the first clip.
            start, end = clips[0]
            start = max(end - round(head_left * fps), start)
            clips[0] = (start, end)
        return clips

    def split_n_frames_by_clip(self, n_frames, clip_indices):
        """
        Distribute ``n_frames`` over the clips in proportion to clip length.

        Allocation is done in units of ``temporal`` frames; units lost to
        flooring are handed out one each to the leading clips. With
        ``length_type == "kn+1"`` the first clip gets one extra frame.
        """
        latent_budget = n_frames // self.temporal
        spans = [end - start for start, end in clip_indices]
        span_total = sum(spans)
        latents_per_clip = [int(span / span_total * latent_budget) for span in spans]

        shortfall = latent_budget - sum(latents_per_clip)
        for slot in range(shortfall):
            latents_per_clip[slot] += 1

        frames_per_clip = [count * self.temporal for count in latents_per_clip]
        if self.length_type == "kn+1":
            frames_per_clip[0] += 1
        return frames_per_clip

    def sample_frame_indices(self, clip_indices, clip_n_frames):
        """
        Pick ``clip_n_frames[i]`` evenly spaced frame indices from clip ``i``.

        Returns one flat list with the clips' indices concatenated in order.
        """
        # Lay the clips end to end on a shared axis starting at 0.
        shifted = []
        offset = 0
        for start, end in clip_indices:
            span = end - start
            shifted.append((offset, offset + span))
            offset += span

        last_pos = len(clip_indices) - 1
        collected = []
        for pos, ((start, end), (shift_start, shift_end), n_frames) in enumerate(
            zip(clip_indices, shifted, clip_n_frames)
        ):
            frame_ids = np.arange(start, end)
            boundary = shift_end if pos >= last_pos else shifted[pos + 1][0]
            # Evenly spaced positions on the shared axis, made clip-relative.
            relative = np.linspace(shift_start, boundary - 1, n_frames, dtype=int) - shift_start
            collected.extend(frame_ids[relative].tolist())
        return collected
def normalize_probabilities(
    items: np.ndarray,
    probs: np.ndarray,
    masks: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Keep the feasible items and fold the infeasible probability mass into
    the last feasible one.

    Args:
        items: candidate items.
        probs: probability of each item.
        masks: boolean feasibility flags; all True entries must come before
            any False entry.

    Returns:
        The feasible items and their probabilities. The total probability
        mass is unchanged, and the inputs are not modified.
    """
    assert len(items), "Items must not be empty."
    assert len(items) == len(masks) == len(probs), "Lengths must match."
    assert isinstance(probs, np.ndarray), "Probs must be an np.ndarray."
    assert isinstance(masks, np.ndarray), "Masks must be an np.ndarray."
    assert masks.dtype == bool, "Masks must be boolean."
    assert np.any(masks), "Masks must not be all False."
    assert np.all(np.diff(masks.astype("int")) <= 0), "Masks must not break monotonicity."

    # Boolean indexing copies, so the caller's arrays stay untouched.
    kept_items = items[masks]
    kept_probs = probs[masks]
    dropped_mass = probs[~masks].sum()
    kept_probs[-1] = kept_probs[-1] + dropped_mass
    return kept_items, kept_probs
def adptive_sample_framelen_and_stride(
    num_frames: int,
    strides: np.ndarray,
    strides_prob: np.ndarray,
    frame_lengths: List[np.ndarray],
    frame_lengths_prob: List[Optional[np.ndarray]],
) -> Dict[str, Any]:
    """Adaptively sample a stride and a frame length for a video.

    Args:
        num_frames: Number of frames in the current video.
        strides: The candidate strides.
        strides_prob: The probability for each stride.
        frame_lengths: For each stride, the (sorted) frame lengths to sample from.
            A frame length ``l`` at stride ``s`` is feasible when
            ``num_frames // s >= l``. Infeasible lengths are dropped and their
            probability is added to the last feasible length; infeasible
            strides are handled the same way.
            For example, with ``num_frames=10`` and lengths ``[1, 3, 5]`` at
            ``stride=3``, only lengths 1 and 3 are feasible, so the probability
            of length 5 is added to length 3.
        frame_lengths_prob: For each stride, the probabilities of its frame
            lengths, or None for uniform sampling.

    Returns:
        A dictionary with the selected ``stride`` and ``frame_length``.
        An assertion fails when no stride has a feasible frame length.
    """
    assert len(strides) == len(strides_prob) == len(frame_lengths) == len(frame_lengths_prob)

    # Per stride: which of its frame lengths fit into the video.
    feasible_by_stride = [
        num_frames // step >= candidates for step, candidates in zip(strides, frame_lengths)
    ]
    # A stride is usable when at least one of its lengths fits.
    stride_usable = np.array([np.any(flags) for flags in feasible_by_stride])
    assert np.any(stride_usable), (
        f"Cannot sample frames={num_frames} "
        + f"from strides={strides} and lengths={frame_lengths}"
    )

    # Drop unusable strides, then draw one by position.
    usable_idxs, usable_probs = normalize_probabilities(
        np.arange(len(strides)), strides_prob, stride_usable
    )
    stride_idx = np.random.choice(usable_idxs, p=usable_probs)
    stride = strides[stride_idx]

    # Same procedure for the frame lengths of the drawn stride.
    candidates = frame_lengths[stride_idx]
    candidate_probs = frame_lengths_prob[stride_idx]
    if candidate_probs is None:
        candidate_probs = np.full(len(candidates), 1.0 / len(candidates))
    candidates, candidate_probs = normalize_probabilities(
        candidates, candidate_probs, feasible_by_stride[stride_idx]
    )
    length = np.random.choice(candidates, p=candidate_probs)

    return {"stride": stride, "frame_length": length}