# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
"""
Frame samplers.
TODO: may need to write a frame sampler that supports custom requirements.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
import numpy as np
class _FrameSamplerOutputFields(NamedTuple):
    """
    Field layout of FrameSamplerOutput.
    Kept as a separate base so the public class can provide its own __new__.
    """

    indices: List[int]
    additional_info: Optional[Dict[str, Any]] = None


class FrameSamplerOutput(_FrameSamplerOutputFields):
    """
    Return indices for frame decoding,
    and optionally additional information to return to user.

    Fields:
        indices: frame indices to decode.
        additional_info: free-form metadata; a new empty dict is created for
            every instance when the argument is omitted (or None).
    """

    __slots__ = ()

    def __new__(
        cls,
        indices: List[int],
        additional_info: Optional[Dict[str, Any]] = None,
    ):
        # A literal `{}` default on a NamedTuple field is evaluated once and then
        # shared by every instance, so mutating one output's additional_info would
        # leak into all the others. Allocate a fresh dict per instance instead.
        if additional_info is None:
            additional_info = {}
        return super().__new__(cls, indices, additional_info)
class FrameSampler(ABC):
    """
    Abstract interface shared by the frame samplers.

    A concrete sampler overrides __call__ and returns the indices of the frames
    that should be decoded. It is expected to raise when the video cannot be
    sampled (for example because it is too short).
    """

    @abstractmethod
    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        """Return the decoding indices; must be overridden by subclasses."""
        raise NotImplementedError
class AllFrameSampler(FrameSampler):
    """
    Sampler that keeps the whole video: every index in [0, num_frames) is returned.
    """

    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        """Return indices 0 .. num_frames - 1 in order."""
        every_index = [*range(num_frames)]
        return FrameSamplerOutput(indices=every_index)
class OnlyFirstFrameSampler:
    """
    Sampler that always selects index 0, i.e. only the first frame of a video.
    """

    def __call__(self, frames_info: Dict[str, int], **kwargs) -> FrameSamplerOutput:
        """Return index 0 only; frames_info and any extra keyword arguments are ignored."""
        first_frame_only = [0]
        return FrameSamplerOutput(indices=first_frame_only)
class FixedFrameSampler:
    """
    Fixed frame-count sampler (one formula for both up- and down-sampling):
    - takes a frames_info dict holding start_frame, end_frame and total_frames;
    - for any total_frames >= 1 it always returns exactly num_frames indices;
    - relative indices are spread evenly over [0, total_frames - 1], shifted by
      start_frame and clamped so that none exceeds end_frame - 1;
    - when total_frames < num_frames indices repeat, e.g. [0,1,2] -> [0,0,1,1,2,2].
    """

    def __init__(self, num_frames: int):
        """
        Arguments:
            num_frames: number of indices to return on every call (>= 1).
        Raises:
            ValueError: if num_frames < 1.
        """
        if num_frames < 1:
            raise ValueError("num_frames must be ≥ 1")
        self.num_frames = num_frames

    def __call__(self, frames_info: Dict[str, int]) -> FrameSamplerOutput:
        """
        Arguments:
            frames_info: dict with 'start_frame', 'end_frame' and 'total_frames'.
        Returns:
            FrameSamplerOutput whose indices are the sampled global frame numbers
            (always num_frames of them) and whose additional_info echoes the
            three input values.
        Raises:
            ValueError: if a required key is missing (or None), or total_frames < 1.
        """
        start = frames_info.get('start_frame')
        total = frames_info.get('total_frames')
        end = frames_info.get('end_frame')
        if start is None or total is None or end is None:
            raise ValueError("frames_info must contain 'start_frame', 'end_frame', and 'total_frames'")
        if total < 1:
            raise ValueError("total_frames must be ≥ 1")
        # Indices relative to the start of the window.
        rel_indices = self._get_indices(total)
        # Shift to global frame numbers and clamp so nothing runs past end - 1.
        indices = [min(start + idx, end - 1) for idx in rel_indices]
        return FrameSamplerOutput(
            indices=indices,
            additional_info={
                "start_frame": start,
                "end_frame": end,
                "total_frames": total,
            },
        )

    def _get_indices(self, total: int) -> List[int]:
        """Return num_frames relative indices evenly spread over [0, total - 1]."""
        # A single requested frame cannot use the formula below (division by zero).
        if self.num_frames == 1:
            return [0]
        # Same formula for up-sampling and down-sampling.
        return [
            int(round(i * (total - 1) / (self.num_frames - 1)))
            for i in range(self.num_frames)
        ]
class ConsecutiveFrameSampler(FrameSampler):
    """
    Consecutive (strided) frame sampler over the window [start_frame, end_frame).

    Arguments:
        strides: frame skip.
            For example, 1 denotes no skip. 2 denotes select every other frame. 3
            denotes select every third frame. When a list is given, stride is randomly
            chosen with even probability. However, user may set it to [1,1,2] to
            denote 1 with 66% probability and 2 with 33% probability.
        temporal: the clip length is the largest value of the form
            temporal * n + 1 that does not exceed (end_frame - start_frame) // stride.
        clip: clip location.
            "head": clip starts at start_frame.
            "tail": clip starts as late as possible while still fitting the window.
            "center": clip video at the center.
            "uniform": clip video uniformly at random.
        jitter: jitter to the location.
            Only applicable when clip is "center".
            The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        strides: Union[int, List[int]] = 1,
        temporal: int = 4,
        clip: Literal["head", "tail", "center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        strides = [strides] if isinstance(strides, int) else strides
        assert len(strides) > 0
        self.strides = np.array(strides)
        self.temporal = temporal
        self.clip = clip
        self.jitter = jitter

    @staticmethod
    def _max_kn_plus_1(length, k):
        """Return the largest integer of the form k * n + 1 that does not exceed length."""
        if length < 1:
            raise ValueError("Length must be at least 1.")
        n = (length - 1) // k
        return k * n + 1

    def __call__(self, frames_info: Dict[str, int]) -> FrameSamplerOutput:
        """
        Arguments:
            frames_info: dict with 'start_frame', 'end_frame' and 'total_frames'.
        Returns:
            FrameSamplerOutput with the sampled indices; additional_info holds the
            chosen stride, the clip's start index, start + length * stride as
            'end_frame', and the input 'total_frames'.
        Raises:
            ValueError: if the window holds fewer frames than the chosen stride.
            NotImplementedError: if clip is not one of the supported values.
        """
        start_frame = frames_info["start_frame"]
        end_frame = frames_info["end_frame"]
        num_frames = frames_info["total_frames"]
        stride = np.random.choice(self.strides)
        frames = end_frame - start_frame
        # Number of frames that fit at this stride, snapped down to temporal * n + 1.
        length = self._max_kn_plus_1(frames // stride, self.temporal)
        # Choose start index.
        min_start_index = start_frame
        max_start_index = end_frame - 1 - stride * (length - 1)
        mid_start_index = round((min_start_index + max_start_index) / 2)
        # Drawn unconditionally (even when unused) so the random stream is the
        # same for every clip mode.
        jitter = round(np.random.normal(loc=0, scale=self.jitter))
        if self.clip == "head":
            start_index = min_start_index
        elif self.clip == "tail":
            start_index = max_start_index
        elif self.clip == "center":
            start_index = mid_start_index + jitter
        elif self.clip == "uniform":
            start_index = np.random.randint(min_start_index, max_start_index + 1)
        else:
            raise NotImplementedError
        # Jitter may push the start outside the valid range; pull it back in.
        start_index = np.clip(start_index, min_start_index, max_start_index)
        # Compute indices
        indices = np.arange(start_index, start_index + length * stride, stride)
        # Return indices and additional information to return to user.
        return FrameSamplerOutput(
            indices=indices.tolist(),
            additional_info={
                "stride": stride,
                "start_frame": start_index,
                "end_frame": start_index + length * stride,
                "total_frames": num_frames,
            },
        )
class AdaptiveFrameSampler(FrameSampler):
    """
    Adaptive frame sampler.

    Arguments:
        lengths: candidate clip lengths (in frames).
            For example, [5,10] means a call returns either 5 or 10 frames; the
            longest candidate that fits the video at the chosen stride is used.
            For example, a 9-frame video is clipped to 5 frames.
        strides: frame skip.
            1 keeps every frame, 2 keeps every other frame, 3 every third frame.
            When a list is given, one entry is drawn with equal probability among
            the strides that fit the video; repeat an entry to weight it, e.g.
            [1,1,2] picks 1 about 66% of the time and 2 about 33% of the time.
        clip: clip location.
            "center": clip video at the center.
            "uniform": clip video uniformly at random.
        jitter: jitter to the location.
            Only applicable when clip is "center".
            The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        lengths: Union[int, List[int]],
        strides: Union[int, List[int]] = 1,
        clip: Literal["center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        # Accept a bare int as a one-element list.
        if isinstance(lengths, int):
            lengths = [lengths]
        if isinstance(strides, int):
            strides = [strides]
        assert len(lengths) > 0
        assert len(strides) > 0
        assert clip in ["center", "uniform"]
        assert jitter >= 0
        self.lengths = np.array(lengths)
        self.strides = np.array(strides)
        self.clip = clip
        self.jitter = jitter

    def __call__(
        self,
        num_frames: int,
    ) -> FrameSamplerOutput:
        """
        Sample a strided clip from a video of num_frames frames.

        Raises:
            ValueError: if no (stride, length) combination fits the video.
        """
        # A stride is usable when at least one candidate length fits under it.
        # Each row of `fits` corresponds to one candidate length.
        frames_per_stride = num_frames // self.strides
        fits = frames_per_stride >= self.lengths.reshape(-1, 1)
        usable_strides = self.strides[np.any(fits, axis=0)]
        if usable_strides.size <= 0:
            raise ValueError(f"Video is too short ({num_frames} frames).")
        stride = np.random.choice(usable_strides)

        # Longest candidate length that fits under the chosen stride.
        budget = num_frames // stride
        length = np.max(self.lengths[budget >= self.lengths])

        # Range of admissible start positions and its midpoint.
        first_allowed = 0
        last_allowed = num_frames - 1 - stride * (length - 1)
        centre = round((first_allowed + last_allowed) / 2)
        # The jitter is drawn before branching, whether or not it ends up being used.
        offset = round(np.random.normal(loc=0, scale=self.jitter))
        if self.clip == "uniform":
            start_index = np.random.randint(first_allowed, last_allowed + 1)
        elif self.clip == "center":
            start_index = centre + offset
        else:
            raise NotImplementedError
        start_index = np.clip(start_index, first_allowed, last_allowed)

        # Strided indices starting at start_index.
        stop_index = start_index + length * stride
        indices = np.arange(start_index, stop_index, stride)

        info = {
            "stride": stride,
            "start_frame": start_index,
            "end_frame": start_index + length * stride,
            "total_frames": num_frames,
        }
        return FrameSamplerOutput(indices=indices.tolist(), additional_info=info)
@dataclass
class AdaptiveAdvancedFrameSamplerStrategy:
    """
    One sampling strategy for AdaptiveAdvancedFrameSampler: a stride, its
    sampling weight, and the frame lengths (with their weights) usable at it.
    """

    # Frame skip; validated as a positive int by AdaptiveAdvancedFrameSampler.
    stride: int
    # Non-negative weight of this stride; the sampler normalizes across strategies.
    stride_prob: float
    # Candidate clip lengths (in frames) at this stride.
    frame_lengths: List[int]
    # "uniform", "harmonic", or one explicit weight per entry of frame_lengths.
    frame_lengths_prob: Union[Literal["uniform", "harmonic"], List[float]]
class AdaptiveAdvancedFrameSampler(FrameSampler):
    """
    Advanced adaptive frame sampler supports different frame lengths for different strides,
    and supports probabilistic sampling of both the stride and the frame length.

    Arguments:
        strategies: A list of strategies to sample from (strides must be unique).
        clip: clip location.
            "center": clip video at the center.
            "uniform": clip video uniformly at random.
        jitter: jitter to the location.
            Only applicable when clip is "center".
            The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        strategies: List[AdaptiveAdvancedFrameSamplerStrategy],
        clip: Literal["center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        assert len(strategies) > 0, "Strategies must not be empty"
        assert len({s.stride for s in strategies}) == len(strategies), "Strides cannot duplicate."
        assert clip in ["center", "uniform"]
        assert jitter >= 0
        self.clip = clip
        self.jitter = jitter
        self.strides = []
        self.strides_prob = []
        self.frame_lengths = []
        self.frame_lengths_prob = []
        # Strategies are stored in ascending stride order.
        for strategy in sorted(strategies, key=lambda s: s.stride):
            # Validate strides.
            assert isinstance(strategy.stride, int), "Stride must be an integer."
            assert strategy.stride > 0, "Stride must be a positive integer."
            self.strides.append(strategy.stride)
            # Assign strides_prob.
            assert isinstance(strategy.stride_prob, (int, float)), "Stride prob is not int/float."
            assert strategy.stride_prob >= 0, "Stride prob must be non-negative."
            self.strides_prob.append(strategy.stride_prob)
            # Assign frame lengths, sort by value.
            assert len(strategy.frame_lengths) > 0, "Frame lengths must not be empty."
            frame_lengths = np.array(strategy.frame_lengths)
            assert frame_lengths.dtype == int, "Frame lengths must be integers."
            assert np.all(frame_lengths > 0), "Frame lengths must be positive integers."
            frame_lengths_sorted_idx = np.argsort(frame_lengths)
            frame_lengths = frame_lengths[frame_lengths_sorted_idx]
            self.frame_lengths.append(frame_lengths)
            # Assign frame lengths prob, apply the sorting to prob as well.
            if strategy.frame_lengths_prob == "uniform":
                # e.g. [0.2, 0.2, 0.2, 0.2, 0.2]
                frame_lengths_prob = np.full(len(frame_lengths), 1.0 / len(frame_lengths))
            elif strategy.frame_lengths_prob == "harmonic":
                # e.g. [0.2, 0.25, 0.33, 0.5, 1]
                frame_lengths_prob = np.flip(1 / np.arange(1, len(frame_lengths) + 1))
            elif isinstance(strategy.frame_lengths_prob, list):
                # Check the size BEFORE reordering: indexing with the sort order
                # would silently drop surplus entries of an over-long list.
                assert len(strategy.frame_lengths_prob) == len(
                    frame_lengths
                ), "Frame lengths prob mismatch."
                # Force float so integer weights survive the in-place division below.
                frame_lengths_prob = np.array(strategy.frame_lengths_prob, dtype=float)
                frame_lengths_prob = frame_lengths_prob[frame_lengths_sorted_idx]
            else:
                raise NotImplementedError
            assert len(frame_lengths_prob) == len(frame_lengths), "Frame lengths prob mismatch."
            assert np.all(frame_lengths_prob >= 0), "Frame lengths prob must not be negative."
            assert frame_lengths_prob.sum() > 0, "Frame lengths prob must not be all zeros."
            frame_lengths_prob /= frame_lengths_prob.sum()
            self.frame_lengths_prob.append(frame_lengths_prob)
        self.strides = np.array(self.strides)
        # Float dtype: stride_prob may legitimately be given as ints, and an int
        # array cannot be normalized in place with `/=`.
        self.strides_prob = np.array(self.strides_prob, dtype=float)
        assert self.strides_prob.sum() > 0, "Strides prob must not be all zeros."
        self.strides_prob /= self.strides_prob.sum()

    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        """
        Sample a stride and frame length that fit num_frames, then place the clip.

        Returns:
            FrameSamplerOutput with the sampled indices; additional_info holds the
            stride, the clip's start index, start + length * stride as
            'end_frame', and num_frames as 'total_frames'.
        """
        sample_result = adptive_sample_framelen_and_stride(
            num_frames=num_frames,
            strides=self.strides,
            strides_prob=self.strides_prob,
            frame_lengths=self.frame_lengths,
            frame_lengths_prob=self.frame_lengths_prob,
        )
        stride = sample_result["stride"]
        length = sample_result["frame_length"]
        # Choose start index.
        min_start_index = 0
        max_start_index = num_frames - 1 - stride * (length - 1)
        mid_start_index = round((min_start_index + max_start_index) / 2)
        jitter = round(np.random.normal(loc=0, scale=self.jitter))
        if self.clip == "center":
            start_index = mid_start_index + jitter
        elif self.clip == "uniform":
            start_index = np.random.randint(min_start_index, max_start_index + 1)
        else:
            raise NotImplementedError
        # Jitter may push the start outside the valid range; pull it back in.
        start_index = np.clip(start_index, min_start_index, max_start_index)
        # Compute indices
        indices = np.arange(start_index, start_index + length * stride, stride)
        # Return indices and additional information to return to user.
        return FrameSamplerOutput(
            indices=indices.tolist(),
            additional_info={
                "stride": stride,
                "start_frame": start_index,
                "end_frame": start_index + length * stride,
                "total_frames": num_frames,
            },
        )
class MultiClipsFrameSampler(FrameSampler):
    """
    multi clips frame sampler.

    Arguments:
        temporal: downsample factor on temporal
        sample_fps: fps of sampled frames
        truncate: whether to truncate by max duration of the video (default = false, already truncate in clip_indices)
        max_duration: truncate by max duration of the video
        length_type: "kn+1" yields temporal * n + 1 frames in total (the extra
            frame goes to the first clip); "kn" yields a multiple of temporal.
        assert_seconds: when True the total duration is rounded to whole seconds
            before being converted to a frame count.
    """

    def __init__(
        self,
        temporal: int = 4,
        sample_fps: int = 12,
        truncate: bool = False,
        max_duration: int = 12,
        length_type: Literal["kn", "kn+1"] = "kn+1",
        assert_seconds: bool = True,
    ):
        self.temporal = temporal
        self.sample_fps = sample_fps
        self.truncate = truncate
        self.max_duration = max_duration
        self.length_type = length_type
        self.assert_seconds = assert_seconds

    def __call__(self, frames_info: Dict[str, Any]) -> FrameSamplerOutput:
        """
        Arguments:
            frames_info: dict with 'clip_indices' (a sequence of (start, end)
                frame ranges) and 'fps' (frame rate those ranges are expressed in).
        Returns:
            FrameSamplerOutput with the sampled indices of all clips concatenated;
            additional_info holds the per-clip frame counts ('clip_n_frames') and
            per-clip latent frame counts ('clip_n_latent_frames').
        """
        clip_indices = frames_info["clip_indices"]
        origin_fps = frames_info["fps"]
        if self.truncate:
            clip_indices = self.truncate_to_bucket(clip_indices, origin_fps)
        if self.assert_seconds:
            duration_sec = int(round(sum([(end - start) / origin_fps for start, end in clip_indices])))
            if not self.truncate:  # Added: cap the total duration even when clips are not truncated.
                duration_sec = min(duration_sec, self.max_duration)
            duration = int(round(duration_sec))
            n_frames = duration * self.sample_fps
            if self.length_type == "kn+1":
                n_frames += 1
        else:
            duration = sum([(end - start) / origin_fps for start, end in clip_indices])
            if not self.truncate:  # Added: same cap as above.
                duration = min(duration, self.max_duration)
            n_frames = int(round(duration * self.sample_fps))
            if self.length_type == "kn+1":
                # Snap down to the largest temporal * n + 1 not exceeding n_frames.
                if n_frames % self.temporal != 0:
                    n_frames = n_frames // self.temporal * self.temporal + 1
                else:
                    n_frames = n_frames // self.temporal * self.temporal + 1 - self.temporal
        clip_n_frames = self.split_n_frames_by_clip(n_frames, clip_indices)
        sample_indices = self.sample_frame_indices(clip_indices, clip_n_frames)
        # Ceil division: the "+1" frame of a kn+1 clip occupies its own latent frame.
        clip_n_latent_frames = [(n + self.temporal - 1) // self.temporal for n in clip_n_frames]
        return FrameSamplerOutput(
            indices=sample_indices,
            additional_info={
                "clip_n_frames": clip_n_frames,
                "clip_n_latent_frames": clip_n_latent_frames,
            },
        )

    def truncate_to_bucket(self, clip_indices, fps):
        """
        Shorten one clip so the total duration becomes min(int(total), max_duration)
        seconds. The tail of the last clip is trimmed when the last clip is longer
        than the first one; otherwise the head of the first clip is trimmed.
        Returns a new list of (start, end) tuples.
        """
        clip_indices = [tuple(index) for index in clip_indices]
        durations = []
        for start, end in clip_indices:
            durations.append((end - start) / fps)
        duration = sum(durations)
        max_duration = min(int(duration), self.max_duration)
        cutoff = duration - max_duration
        if cutoff > 0:
            if durations[-1] - cutoff > durations[0] - cutoff:  # Trim the tail.
                start, end = clip_indices[-1]
                end = min(round((durations[-1] - cutoff) * fps), end) + start
                clip_indices[-1] = (start, end)
            else:  # Trim the head.
                start, end = clip_indices[0]
                start = max(end - round((durations[0] - cutoff) * fps), start)
                clip_indices[0] = (start, end)
        return clip_indices

    def split_n_frames_by_clip(self, n_frames, clip_indices):
        """
        Distribute n_frames over the clips proportionally to their lengths, in
        units of `temporal` frames. Returns the frame count of every clip.
        """
        n_latent_frames = n_frames // self.temporal
        clip_lengths = [(end - start) for start, end in clip_indices]
        # Computed once instead of once per clip inside the comprehension.
        total_length = sum(clip_lengths)
        clip_n_latent_frames = [
            int(length / total_length * n_latent_frames) for length in clip_lengths
        ]
        # Flooring may leave a few latent frames unassigned; give them to the first clips.
        n_remains = n_latent_frames - sum(clip_n_latent_frames)
        for i in range(n_remains):
            clip_n_latent_frames[i] += 1
        clip_n_frames = [n * self.temporal for n in clip_n_latent_frames]
        if self.length_type == "kn+1":
            clip_n_frames[0] += 1
        return clip_n_frames

    def sample_frame_indices(self, clip_indices, clip_n_frames):
        """
        For each clip pick clip_n_frames[i] evenly spaced frame indices from
        [start, end) and return all of them concatenated in clip order.
        """
        # Re-express every clip in a contiguous coordinate system starting at 0.
        shift_clip_indices = []
        accum_n_frames = 0
        for start, end in clip_indices:
            start, end = accum_n_frames, accum_n_frames + (end - start)
            shift_clip_indices.append((start, end))
            accum_n_frames += end - start
        all_sample_indices = []
        for i, ((start, end), (shift_start, shift_end), n_frames) in enumerate(
            zip(clip_indices, shift_clip_indices, clip_n_frames)
        ):
            indices = np.arange(start, end)
            next_shift_start = (
                shift_clip_indices[i + 1][0] if i < len(clip_indices) - 1 else shift_end
            )
            # Evenly spaced positions inside the clip, as offsets from its start.
            shift_sample_indices = (
                np.linspace(shift_start, next_shift_start - 1, n_frames, dtype=int) - shift_start
            )
            sample_indices = indices[shift_sample_indices].tolist()
            all_sample_indices.extend(sample_indices)
        return all_sample_indices
def normalize_probabilities(
    items: np.ndarray,
    probs: np.ndarray,
    masks: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Keep the feasible items and fold the infeasible probability mass into the last one.

    Args:
        items: Candidate items; indexed with the boolean mask.
        probs: Probability of each item.
        masks: Boolean feasibility flags. Feasible entries must form a prefix
            (no True after a False) and at least one entry must be True.

    Returns:
        The feasible items and their probabilities, where the last feasible
        probability additionally carries the sum of all infeasible ones.
        The input arrays are left untouched.
    """
    assert len(items), "Items must not be empty."
    assert len(items) == len(masks) == len(probs), "Lengths must match."
    assert isinstance(probs, np.ndarray), "Probs must be an np.ndarray."
    assert isinstance(masks, np.ndarray), "Masks must be an np.ndarray."
    assert masks.dtype == bool, "Masks must be boolean."
    assert np.any(masks), "Masks must not be all False."
    # A rising step (False -> True) would mean a feasible item after an infeasible one.
    mask_steps = np.diff(masks.astype("int"))
    assert not np.any(mask_steps > 0), "Masks must not break monotonicity."

    feasible_items = items[masks]
    # Boolean indexing copies, so the update below does not touch `probs`.
    feasible_probs = probs[masks]
    dropped_mass = probs[np.logical_not(masks)].sum()
    feasible_probs[-1] = feasible_probs[-1] + dropped_mass
    return feasible_items, feasible_probs
def adptive_sample_framelen_and_stride(
    num_frames: int,
    strides: np.ndarray,
    strides_prob: np.ndarray,
    frame_lengths: List[np.ndarray],
    frame_lengths_prob: List[Optional[np.ndarray]],
) -> Dict[str, Any]:
    """Adaptively sample a stride and a frame length that fit the video.

    A frame length `l` is feasible at stride `s` when `num_frames // s >= l`;
    for example `frame_length=10` at `stride=2` needs 20 frames. Infeasible
    choices are dropped and their probability is added to the last feasible
    one. For example, with `num_frames=10`, lengths `[1, 2, 5]` at `stride=2`
    are all feasible, while of lengths `[1, 3, 5]` at `stride=3` only 1 and 3
    are; the probability of length 5 at `stride=3` is then added to length 3.

    Args:
        num_frames: Number of frames in the current video.
        strides: The candidate strides.
        strides_prob: The probability for each stride.
        frame_lengths: For each stride, the (sorted) frame lengths to sample from.
        frame_lengths_prob: For each stride, the probability of each frame
            length; None means uniform.

    Returns:
        A dict with keys "stride" and "frame_length".

    Raises:
        AssertionError: if the argument lengths disagree or nothing is feasible.
    """
    assert len(strides) == len(strides_prob) == len(frame_lengths) == len(frame_lengths_prob)

    # Per stride: which of its frame lengths fit into the video.
    feasible_per_stride = [num_frames // s >= l for s, l in zip(strides, frame_lengths)]

    # A stride is feasible as soon as one of its frame lengths is.
    stride_feasible = np.array([np.any(mask) for mask in feasible_per_stride])
    assert np.any(stride_feasible), (
        f"Cannot sample frames={num_frames} "
        + f"from strides={strides} and lengths={frame_lengths}"
    )

    # Sample among the positions of the feasible strides.
    candidate_positions, candidate_probs = normalize_probabilities(
        np.arange(len(strides)), strides_prob, stride_feasible
    )
    chosen = np.random.choice(candidate_positions, p=candidate_probs)

    # Restrict the chosen stride's frame lengths to the feasible ones.
    length_probs = frame_lengths_prob[chosen]
    if length_probs is None:
        n_options = len(frame_lengths[chosen])
        length_probs = np.full(n_options, 1.0 / n_options)
    length_options, length_probs = normalize_probabilities(
        frame_lengths[chosen], length_probs, feasible_per_stride[chosen]
    )

    # Sample the frame length.
    picked_length = np.random.choice(length_options, p=length_probs)
    return {"stride": strides[chosen], "frame_length": picked_length}
|