| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Frame samplers. |
| |
| TODO: a frame sampler that supports custom requirements may still need to be written. |
| """ |
|
|
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union |
| import numpy as np |
|
|
|
|
class _FrameSamplerOutputFields(NamedTuple):
    """Tuple layout backing ``FrameSamplerOutput`` (no defaults on purpose)."""

    indices: List[int]
    additional_info: Dict[str, Any]


class FrameSamplerOutput(_FrameSamplerOutputFields):
    """
    Return indices for frame decoding,
    and optionally additional information to return to user.

    Fields:
        indices: frame indices to decode.
        additional_info: free-form metadata about the sampling decision.
            Defaults to a new empty dict for each instance.
    """

    __slots__ = ()

    def __new__(
        cls,
        indices: List[int],
        additional_info: Optional[Dict[str, Any]] = None,
    ):
        # A literal `{}` default on a NamedTuple field is a single dict shared by
        # every instance created without the argument; mutating it through one
        # output would leak into all later ones. Build a fresh dict instead.
        if additional_info is None:
            additional_info = {}
        return super().__new__(cls, indices, additional_info)
|
|
|
|
class FrameSampler(ABC):
    """
    Abstract base class for frame samplers.

    A concrete sampler overrides ``__call__`` and returns the indices that
    should be decoded. If the video cannot be sampled (for example because it
    is too short) the sampler should raise instead.
    """

    @abstractmethod
    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        # Subclasses provide the actual sampling policy.
        raise NotImplementedError
|
|
|
|
class AllFrameSampler(FrameSampler):
    """
    Sampler that keeps every frame of the video.

    Returns the indices ``0 .. num_frames - 1`` with no additional info.
    """

    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        every_index = [*range(num_frames)]
        return FrameSamplerOutput(indices=every_index)
|
|
class OnlyFirstFrameSampler:
    """
    Sampler that returns only the first frame of a video.

    ``frames_info`` and any extra keyword arguments are accepted but ignored;
    the result is always the single index ``0``.
    """

    def __call__(self, frames_info: Dict[str, int], **kwargs) -> FrameSamplerOutput:
        first_only = [0]
        return FrameSamplerOutput(indices=first_only)
|
|
class FixedFrameSampler:
    """
    Fixed-count frame sampler (one algorithm for both up- and down-sampling).

    - Takes a ``frames_info`` dict with ``start_frame``, ``end_frame`` and
      ``total_frames``.
    - For any ``total_frames >= 1`` it always returns exactly ``num_frames``
      frame indices.
    - Relative positions are spread evenly over ``[0, total_frames - 1]``,
      shifted by ``start_frame`` and clamped to ``end_frame - 1``. The first
      index is therefore ``start_frame`` and the last one is
      ``min(start_frame + total_frames - 1, end_frame - 1)``.
    - When ``total_frames < num_frames`` indices repeat, e.g. 3 frames sampled
      6 times gives ``[0, 0, 1, 1, 2, 2]``.
    """

    def __init__(self, num_frames: int):
        """
        Args:
            num_frames: number of indices every call returns.

        Raises:
            ValueError: if ``num_frames`` is smaller than 1.
        """
        if num_frames < 1:
            raise ValueError("num_frames must be ≥ 1")
        self.num_frames = num_frames

    def __call__(self, frames_info: Dict[str, int]) -> "FrameSamplerOutput":
        """
        Sample ``num_frames`` global frame indices.

        Args:
            frames_info: dict containing 'start_frame', 'end_frame' and
                'total_frames'.

        Returns:
            FrameSamplerOutput whose ``indices`` always has length ``num_frames``
            and whose ``additional_info`` echoes the three input values.

        Raises:
            ValueError: if a key is missing, ``total_frames < 1``, or the range
                ``[start_frame, end_frame)`` is empty.
        """
        start = frames_info.get("start_frame")
        total = frames_info.get("total_frames")
        end = frames_info.get("end_frame")
        if start is None or total is None or end is None:
            raise ValueError("frames_info must contain 'start_frame', 'end_frame', and 'total_frames'")
        if total < 1:
            raise ValueError("total_frames must be ≥ 1")
        if end <= start:
            # Clamping to end - 1 would otherwise produce indices below
            # start_frame (even negative ones) for an empty/inverted range.
            raise ValueError("end_frame must be greater than start_frame")

        last_valid = end - 1
        indices = [min(start + offset, last_valid) for offset in self._get_indices(total)]

        return FrameSamplerOutput(
            indices=indices,
            additional_info={
                "start_frame": start,
                "end_frame": end,
                "total_frames": total,
            },
        )

    def _get_indices(self, total: int) -> List[int]:
        """Return ``num_frames`` evenly spaced offsets in ``[0, total - 1]``."""
        if self.num_frames == 1:
            # A single sample has no spacing to compute; take the first frame.
            return [0]

        return [
            int(round(i * (total - 1) / (self.num_frames - 1)))
            for i in range(self.num_frames)
        ]
|
|
|
|
class ConsecutiveFrameSampler(FrameSampler):
    """
    Consecutive frame sampler.

    Samples one evenly strided run of frames from the range
    ``[start_frame, end_frame)`` given in ``frames_info``. The run length is the
    largest value of the form ``temporal * n + 1`` that fits in the range at the
    chosen stride.

    Arguments:
        strides: frame skip.
            For example, 1 denotes no skip. 2 denotes select every other frame. 3
            denotes select every third frame. When a list is given, stride is randomly
            chosen with even probability. However, user may set it to [1,1,2] to
            denote 1 with 66% probability and 2 with 33% probability.
        temporal: the ``k`` in the ``k * n + 1`` run length described above.
        clip: clip location.
            "head": start at ``start_frame``.
            "tail": start as late as the range allows.
            "center": clip video at the center.
            "uniform": clip video uniformly at random.
        jitter: jitter to the location.
            Only applicable when clip is "center".
            The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        strides: Union[int, List[int]] = 1,
        temporal: int = 4,
        clip: Literal["head", "tail", "center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        strides = [strides] if isinstance(strides, int) else strides
        assert len(strides) > 0
        self.strides = np.array(strides)
        self.temporal = temporal
        self.clip = clip
        self.jitter = jitter

    @staticmethod
    def _max_kn_plus_1(length: int, k: int) -> int:
        """Return the largest value ``k * n + 1`` (n >= 0) not exceeding ``length``."""
        if length < 1:
            raise ValueError("Length must be at least 1.")
        n = (length - 1) // k
        return k * n + 1

    def __call__(self, frames_info: Dict[str, int]) -> FrameSamplerOutput:
        """
        Sample a strided run of frame indices.

        Args:
            frames_info: dict with 'start_frame', 'end_frame' and 'total_frames'.

        Returns:
            FrameSamplerOutput with the sampled indices; ``additional_info``
            holds the chosen stride, the run's start, its exclusive end
            (``start + length * stride``) and the input 'total_frames'.

        Raises:
            ValueError: if the range holds no frame at the chosen stride.
            NotImplementedError: if ``clip`` is not a supported mode.
        """
        start_frame = frames_info["start_frame"]
        end_frame = frames_info["end_frame"]
        num_frames = frames_info["total_frames"]

        stride = np.random.choice(self.strides)

        # Number of frames reachable at this stride, snapped down to k * n + 1.
        frames = end_frame - start_frame
        length = self._max_kn_plus_1(frames // stride, self.temporal)

        # Range of start positions that keep the whole run inside the clip.
        min_start_index = start_frame
        max_start_index = end_frame - 1 - stride * (length - 1)

        mid_start_index = round((min_start_index + max_start_index) / 2)
        # Drawn for every clip mode (only "center" uses it); kept unconditional
        # so the order of random draws is unchanged.
        jitter = round(np.random.normal(loc=0, scale=self.jitter))

        if self.clip == "head":
            start_index = min_start_index
        elif self.clip == "tail":
            start_index = max_start_index
        elif self.clip == "center":
            start_index = mid_start_index + jitter
        elif self.clip == "uniform":
            start_index = np.random.randint(min_start_index, max_start_index + 1)
        else:
            raise NotImplementedError

        # Jitter may push the start outside the valid range; clamp it back.
        start_index = np.clip(start_index, min_start_index, max_start_index)

        indices = np.arange(start_index, start_index + length * stride, stride)

        return FrameSamplerOutput(
            indices=indices.tolist(),
            additional_info={
                "stride": stride,
                "start_frame": start_index,
                "end_frame": start_index + length * stride,
                "total_frames": num_frames,
            },
        )
|
|
|
|
class AdaptiveFrameSampler(FrameSampler):
    """
    Adaptive frame sampler.

    Arguments:
        lengths: frame length to return.
            For example, [5,10] denotes to always return 5 frames or 10 frames.
            The longest length that fits the video at the chosen stride is used.
            For example, a 9-frame video is clipped to 5 frames.
        strides: frame skip.
            1 means no skip, 2 means every other frame, 3 means every third
            frame. With a list, one stride is drawn with even probability among
            the entries that fit the video; repeating an entry (e.g. [1,1,2])
            weights it accordingly.
        clip: clip location.
            "center": clip video at the center.
            "uniform": clip video uniformly at random.
        jitter: jitter to the location.
            Only applicable when clip is "center".
            The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        lengths: Union[int, List[int]],
        strides: Union[int, List[int]] = 1,
        clip: Literal["center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        # Accept a bare int as a one-element list.
        if isinstance(lengths, int):
            lengths = [lengths]
        if isinstance(strides, int):
            strides = [strides]
        assert len(lengths) > 0
        assert len(strides) > 0
        assert clip in ["center", "uniform"]
        assert jitter >= 0
        self.lengths = np.array(lengths)
        self.strides = np.array(strides)
        self.clip = clip
        self.jitter = jitter

    def __call__(
        self,
        num_frames: int,
    ) -> FrameSamplerOutput:
        # A stride is usable when at least one configured length fits at it.
        # Rows of `fits` are lengths, columns are strides.
        frames_per_stride = num_frames // self.strides
        fits = frames_per_stride >= self.lengths.reshape(-1, 1)
        usable_strides = self.strides[np.any(fits, axis=0)]
        if usable_strides.size <= 0:
            raise ValueError(f"Video is too short ({num_frames} frames).")
        stride = np.random.choice(usable_strides)

        # Longest configured length that fits at the chosen stride.
        length = np.max(self.lengths[num_frames // stride >= self.lengths])

        # Start positions that keep the whole run inside the video.
        lowest_start = 0
        highest_start = num_frames - 1 - stride * (length - 1)
        centered_start = round((lowest_start + highest_start) / 2)
        shift = round(np.random.normal(loc=0, scale=self.jitter))

        if self.clip == "center":
            start_index = centered_start + shift
        elif self.clip == "uniform":
            start_index = np.random.randint(lowest_start, highest_start + 1)
        else:
            raise NotImplementedError

        # Jitter may leave the valid range; clamp back into it.
        start_index = np.clip(start_index, lowest_start, highest_start)

        stop_index = start_index + length * stride
        indices = np.arange(start_index, stop_index, stride)

        info = {
            "stride": stride,
            "start_frame": start_index,
            "end_frame": start_index + length * stride,
            "total_frames": num_frames,
        }
        return FrameSamplerOutput(indices=indices.tolist(), additional_info=info)
|
|
|
|
@dataclass
class AdaptiveAdvancedFrameSamplerStrategy:
    """
    One stride together with the frame lengths that may be sampled at it.

    Consumed by ``AdaptiveAdvancedFrameSampler``, which validates the values
    and normalizes both probability fields.
    """

    # Frame skip for this strategy.
    stride: int
    # Weight of this stride relative to the other strategies.
    stride_prob: float
    # Candidate numbers of frames to return at this stride.
    frame_lengths: List[int]
    # "uniform", "harmonic", or one explicit weight per entry of frame_lengths.
    frame_lengths_prob: Union[Literal["uniform", "harmonic"], List[float]]
|
|
|
|
class AdaptiveAdvancedFrameSampler(FrameSampler):
    """
    Advanced adaptive frame sampler supports different frame lengths for different strides,
    and supports probabilistic sampling of both the stride and the frame length.

    strategies: A list of strategies to sample from.
    clip: clip location.
        "center": clip video at the center.
        "uniform": clip video uniformly at random.
    jitter: jitter to the location.
        Only applicable when clip is "center".
        The value is the stdev of the normal distribution to shift the index.
    """

    def __init__(
        self,
        strategies: List[AdaptiveAdvancedFrameSamplerStrategy],
        clip: Literal["center", "uniform"] = "uniform",
        jitter: float = 0.0,
    ):
        assert len(strategies) > 0, "Strategies must not be empty"
        assert len({s.stride for s in strategies}) == len(strategies), "Strides cannot duplicate."
        assert clip in ["center", "uniform"]
        assert jitter >= 0
        self.clip = clip
        self.jitter = jitter
        self.strides = []
        self.strides_prob = []
        self.frame_lengths = []
        self.frame_lengths_prob = []

        # Strategies are stored in ascending stride order.
        for strategy in sorted(strategies, key=lambda s: s.stride):
            # Validate the stride.
            assert isinstance(strategy.stride, int), "Stride must be an integer."
            assert strategy.stride > 0, "Stride must be a positive integer."
            self.strides.append(strategy.stride)

            # Validate the stride probability.
            assert isinstance(strategy.stride_prob, (int, float)), "Stride prob is not int/float."
            assert strategy.stride_prob >= 0, "Stride prob must be non-negative."
            self.strides_prob.append(strategy.stride_prob)

            # Validate the frame lengths and sort them ascending.
            assert len(strategy.frame_lengths) > 0, "Frame lengths must not be empty."
            frame_lengths = np.array(strategy.frame_lengths)
            # Accept any integer dtype, not only the platform default int.
            assert np.issubdtype(frame_lengths.dtype, np.integer), "Frame lengths must be integers."
            assert np.all(frame_lengths > 0), "Frame lengths must be positive integers."
            frame_lengths_sorted_idx = np.argsort(frame_lengths)
            frame_lengths = frame_lengths[frame_lengths_sorted_idx]
            self.frame_lengths.append(frame_lengths)

            # Build the per-length probabilities, aligned with the sorted lengths.
            if strategy.frame_lengths_prob == "uniform":
                # Every length is equally likely.
                frame_lengths_prob = np.full(len(frame_lengths), 1.0 / len(frame_lengths))
            elif strategy.frame_lengths_prob == "harmonic":
                # Weights 1/n, ..., 1/2, 1 in ascending length order.
                frame_lengths_prob = np.flip(1 / np.arange(1, len(frame_lengths) + 1))
            elif isinstance(strategy.frame_lengths_prob, list):
                # Check the size before reindexing: indexing with the sort order
                # would silently drop the extra entries of a too-long list.
                assert len(strategy.frame_lengths_prob) == len(
                    frame_lengths
                ), "Frame lengths prob mismatch."
                # Force float so integer weights survive the in-place division below.
                frame_lengths_prob = np.array(strategy.frame_lengths_prob, dtype=float)
                frame_lengths_prob = frame_lengths_prob[frame_lengths_sorted_idx]
            else:
                raise NotImplementedError
            assert len(frame_lengths_prob) == len(frame_lengths), "Frame lengths prob mismatch."
            assert np.all(frame_lengths_prob >= 0), "Frame lengths prob must not be negative."
            assert frame_lengths_prob.sum() > 0, "Frame lengths prob must not be all zeros."
            frame_lengths_prob /= frame_lengths_prob.sum()
            self.frame_lengths_prob.append(frame_lengths_prob)

        self.strides = np.array(self.strides)
        # Float dtype: all-integer stride probs would otherwise make the
        # in-place normalization fail with a casting error.
        self.strides_prob = np.array(self.strides_prob, dtype=float)
        assert self.strides_prob.sum() > 0, "Strides prob must not be all zeros."
        self.strides_prob /= self.strides_prob.sum()

    def __call__(self, num_frames: int) -> FrameSamplerOutput:
        """
        Sample a stride and frame length, then place the run inside the video.

        Args:
            num_frames: number of frames in the video.

        Returns:
            FrameSamplerOutput with the sampled indices; ``additional_info``
            holds the stride, the run's start, its exclusive end
            (``start + length * stride``) and ``num_frames``.
        """
        sample_result = adptive_sample_framelen_and_stride(
            num_frames=num_frames,
            strides=self.strides,
            strides_prob=self.strides_prob,
            frame_lengths=self.frame_lengths,
            frame_lengths_prob=self.frame_lengths_prob,
        )

        stride = sample_result["stride"]
        length = sample_result["frame_length"]

        # Start positions that keep the whole run inside the video.
        min_start_index = 0
        max_start_index = num_frames - 1 - stride * (length - 1)
        mid_start_index = round((min_start_index + max_start_index) / 2)
        jitter = round(np.random.normal(loc=0, scale=self.jitter))

        if self.clip == "center":
            start_index = mid_start_index + jitter
        elif self.clip == "uniform":
            start_index = np.random.randint(min_start_index, max_start_index + 1)
        else:
            raise NotImplementedError

        # Jitter may leave the valid range; clamp back into it.
        start_index = np.clip(start_index, min_start_index, max_start_index)

        indices = np.arange(start_index, start_index + length * stride, stride)

        return FrameSamplerOutput(
            indices=indices.tolist(),
            additional_info={
                "stride": stride,
                "start_frame": start_index,
                "end_frame": start_index + length * stride,
                "total_frames": num_frames,
            },
        )
|
|
|
|
class MultiClipsFrameSampler(FrameSampler):
    """
    Frame sampler for a video made of several clips.

    Arguments:
        temporal: downsample factor on temporal
        sample_fps: fps of sampled frames
        truncate: whether to truncate by max duration of the video
            (default = false, already truncate in clip_indices)
        max_duration: truncate by max duration of the video
        length_type: "kn" or "kn+1"; with "kn+1" one extra frame is added and
            assigned to the first clip.
        assert_seconds: when true, the total duration is rounded to whole
            seconds before the frame count is derived.
    """

    def __init__(
        self,
        temporal: int = 4,
        sample_fps: int = 12,
        truncate: bool = False,
        max_duration: int = 12,
        length_type: Literal["kn", "kn+1"] = "kn+1",
        assert_seconds: bool = True,
    ):
        self.temporal = temporal
        self.sample_fps = sample_fps
        self.truncate = truncate
        self.max_duration = max_duration
        self.length_type = length_type
        self.assert_seconds = assert_seconds

    def __call__(self, frames_info: Dict[str, int]) -> FrameSamplerOutput:
        """
        Sample frame indices across all clips.

        Reads 'clip_indices' ((start, end) pairs) and 'fps' from
        ``frames_info``. ``additional_info`` reports the per-clip frame counts
        and the per-clip latent frame counts (frame count divided by
        ``temporal``, rounded up).
        """
        clip_indices = frames_info["clip_indices"]
        origin_fps = frames_info["fps"]

        if self.truncate:
            clip_indices = self.truncate_to_bucket(clip_indices, origin_fps)

        # Total clip duration in seconds at the source fps.
        total_seconds = sum([(end - start) / origin_fps for start, end in clip_indices])

        if self.assert_seconds:
            duration_sec = int(round(total_seconds))
            if not self.truncate:
                # Without explicit truncation the duration is still capped.
                duration_sec = min(duration_sec, self.max_duration)
            duration = int(round(duration_sec))

            n_frames = duration * self.sample_fps
            if self.length_type == "kn+1":
                n_frames += 1
        else:
            duration = total_seconds
            if not self.truncate:
                duration = min(duration, self.max_duration)
            n_frames = int(round(duration * self.sample_fps))
            if self.length_type == "kn+1":
                # Snap down to a count of the form temporal * n + 1.
                floored = n_frames // self.temporal * self.temporal
                if n_frames % self.temporal != 0:
                    n_frames = floored + 1
                else:
                    n_frames = floored + 1 - self.temporal

        clip_n_frames = self.split_n_frames_by_clip(n_frames, clip_indices)
        sample_indices = self.sample_frame_indices(clip_indices, clip_n_frames)

        # Ceil division of each clip's frame count by the temporal factor.
        clip_n_latent_frames = [
            (count + self.temporal - 1) // self.temporal for count in clip_n_frames
        ]

        info = {
            "clip_n_frames": clip_n_frames,
            "clip_n_latent_frames": clip_n_latent_frames,
        }
        return FrameSamplerOutput(indices=sample_indices, additional_info=info)

    def truncate_to_bucket(self, clip_indices, fps):
        """
        Shorten one clip so the total duration matches the allowed maximum.

        The target is ``min(int(total_duration), self.max_duration)``. The
        excess is removed from the end of the last clip when that clip is
        longer than the first one, otherwise from the start of the first clip.
        Returns a new list of (start, end) tuples.
        """
        clips = [tuple(pair) for pair in clip_indices]
        durations = [(end - start) / fps for start, end in clips]
        duration = sum(durations)
        max_duration = min(int(duration), self.max_duration)
        cutoff = duration - max_duration
        if not cutoff > 0:
            return clips

        if durations[-1] - cutoff > durations[0] - cutoff:
            start, end = clips[-1]
            end = min(round((durations[-1] - cutoff) * fps), end) + start
            clips[-1] = (start, end)
        else:
            start, end = clips[0]
            start = max(end - round((durations[0] - cutoff) * fps), start)
            clips[0] = (start, end)
        return clips

    def split_n_frames_by_clip(self, n_frames, clip_indices):
        """
        Distribute ``n_frames`` over the clips proportionally to their length.

        The split is done in latent frames (``n_frames // temporal``); frames
        lost to flooring are handed out one each to the leading clips. With
        length type "kn+1" the first clip receives one extra frame.
        """
        n_latent_frames = n_frames // self.temporal
        clip_lengths = [(end - start) for start, end in clip_indices]
        clip_n_latent_frames = [
            int(length / sum(clip_lengths) * n_latent_frames) for length in clip_lengths
        ]
        n_remains = n_latent_frames - sum(clip_n_latent_frames)
        for position in range(n_remains):
            clip_n_latent_frames[position] += 1
        clip_n_frames = [latent * self.temporal for latent in clip_n_latent_frames]
        if self.length_type == "kn+1":
            clip_n_frames[0] += 1
        return clip_n_frames

    def sample_frame_indices(self, clip_indices, clip_n_frames):
        """
        Pick ``clip_n_frames[i]`` evenly spaced frame indices from clip ``i``.

        Returns one flat list of original frame indices, clip after clip.
        """
        # Lay the clips end to end on one contiguous timeline starting at 0.
        shift_clip_indices = []
        accum_n_frames = 0
        for start, end in clip_indices:
            shifted_start, shifted_end = accum_n_frames, accum_n_frames + (end - start)
            shift_clip_indices.append((shifted_start, shifted_end))
            accum_n_frames += shifted_end - shifted_start

        last_clip = len(clip_indices) - 1
        all_sample_indices = []
        for i, ((start, end), (shift_start, shift_end), n_frames) in enumerate(
            zip(clip_indices, shift_clip_indices, clip_n_frames)
        ):
            frame_ids = np.arange(start, end)
            if i < last_clip:
                next_shift_start = shift_clip_indices[i + 1][0]
            else:
                next_shift_start = shift_end
            # Evenly spaced positions inside the clip, as offsets from its start.
            offsets = np.linspace(shift_start, next_shift_start - 1, n_frames, dtype=int)
            offsets = offsets - shift_start
            all_sample_indices.extend(frame_ids[offsets].tolist())

        return all_sample_indices
|
|
|
|
def normalize_probabilities(
    items: np.ndarray,
    probs: np.ndarray,
    masks: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Keep the unmasked items and move the masked-out probability onto the last kept one.

    Args:
        items: candidate items.
        probs: probability of each item.
        masks: boolean array; True marks an item to keep. All True entries must
            come before any False entry.

    Returns:
        The kept items and their probabilities, where the last kept probability
        also carries the sum of the dropped ones. ``probs`` itself is not modified.
    """
    assert len(items), "Items must not be empty."
    assert len(items) == len(masks) == len(probs), "Lengths must match."

    assert isinstance(probs, np.ndarray), "Probs must be an np.ndarray."
    assert isinstance(masks, np.ndarray), "Masks must be an np.ndarray."
    assert masks.dtype == bool, "Masks must be boolean."
    assert np.any(masks), "Masks must not be all False."
    assert np.all(np.diff(masks.astype("int")) <= 0), "Masks must not break monotonicity."

    dropped = np.logical_not(masks)
    kept_items, kept_probs = items[masks], probs[masks]

    # Boolean indexing copied the data, so this does not touch `probs`.
    kept_probs[-1] += probs[dropped].sum()

    return kept_items, kept_probs
|
|
|
|
def adptive_sample_framelen_and_stride(
    num_frames: int,
    strides: np.ndarray,
    strides_prob: np.ndarray,
    frame_lengths: List[np.ndarray],
    frame_lengths_prob: List[Optional[np.ndarray]],
) -> Dict[str, Any]:
    """Adaptively sample frame length and stride for a video.

    Args:
        num_frames: Number of frames in the current video.
        strides: A list of strides.
        strides_prob: The probability for each stride.
        frame_lengths: For each stride, the (sorted) frame lengths to sample from.
            A frame length is feasible at a stride when
            `num_frames // stride >= frame_length`; for example `frame_length=10`
            at `stride=2` needs 20 frames. Infeasible lengths are dropped and
            their probability is added to the largest feasible length of the
            same stride. For example, with `num_frames=10`,
            `frame_lengths[stride2]=[1, 2, 5]` and `frame_lengths[stride3]=[1, 3, 5]`,
            all of 1, 2 and 5 are feasible at `stride=2`, but only 1 and 3 at
            `stride=3`; the probability of `frame_length=5` at `stride=3` is
            added to `frame_length=3`, making it more likely to be selected.
            Strides without any feasible length are dropped the same way.
        frame_lengths_prob: For each stride, the probabilities of its frame lengths.
            A `None` entry means uniform sampling.
    Returns:
        dictionary: `stride` and `frame_length` that were selected. An assertion
            fails if no stride has a feasible frame length.
    """
    assert len(strides) == len(strides_prob) == len(frame_lengths) == len(frame_lengths_prob)

    # One boolean mask per stride: which of its frame lengths fit the video.
    feasible_lengths = [num_frames // s >= l for s, l in zip(strides, frame_lengths)]

    # A stride is feasible when at least one of its lengths fits.
    stride_positions = np.arange(len(strides))
    feasible_strides = np.array([np.any(mask) for mask in feasible_lengths])
    assert np.any(feasible_strides), (
        f"Cannot sample frames={num_frames} "
        + f"from strides={strides} and lengths={frame_lengths}"
    )

    # Drop infeasible strides, keeping the total probability mass.
    stride_positions, stride_weights = normalize_probabilities(
        stride_positions, strides_prob, feasible_strides
    )

    # Draw the stride (by position, so the per-stride lists can be indexed).
    stride_idx = np.random.choice(stride_positions, p=stride_weights)
    stride = strides[stride_idx]

    # Candidate lengths for the drawn stride.
    candidates = frame_lengths[stride_idx]
    candidates_mask = feasible_lengths[stride_idx]
    candidates_prob = frame_lengths_prob[stride_idx]
    if candidates_prob is None:
        candidates_prob = np.full(len(candidates), 1.0 / len(candidates))

    # Drop infeasible lengths, keeping the total probability mass.
    candidates, candidates_prob = normalize_probabilities(
        candidates, candidates_prob, candidates_mask
    )

    # Draw the frame length.
    length = np.random.choice(candidates, p=candidates_prob)
    return {"stride": stride, "frame_length": length}
|
|