| |
|
|
| import os |
| from logging import getLogger |
| from typing import Callable, List, Tuple |
|
|
| import torch |
| from PIL import Image |
| from torchcodec.decoders import VideoDecoder |
| from torchvision.utils import draw_bounding_boxes |
|
|
| from core.transforms.image_transform import ImageTransform |
|
|
| logger = getLogger() |
|
|
|
|
def get_video_transform(
    image_res: int = 224,
    normalize_img: bool = True,
) -> "VideoTransform":
    """Build the callable used to preprocess videos.

    Args:
        image_res: Target square frame resolution in pixels.
        normalize_img: Whether the underlying ImageTransform normalizes
            pixel values.

    Returns:
        A configured ``VideoTransform`` instance (a callable). Note: this
        function returns the transform only, not a ``(transform, n)`` tuple.
    """
    return VideoTransform(
        size=image_res,
        normalize_img=normalize_img,
    )
|
|
|
|
class VideoTransform(ImageTransform):
    """Video preprocessing transform.

    Decodes a video with torchcodec, samples frames at a target FPS,
    optionally draws per-frame bounding boxes, then applies the parent
    ``ImageTransform`` tensor preprocessing.
    """

    def __init__(
        self,
        size: int = 224,
        normalize_img: bool = True,
    ) -> None:
        """
        Args:
            size: Target square resolution forwarded to ImageTransform.
            normalize_img: Whether the parent transform normalizes pixels.
        """
        super().__init__(
            size=size,
            normalize_img=normalize_img,
        )

    def __call__(self, video_info: tuple, sampling_fps: int = 1):
        """Load, annotate, and preprocess a video clip.

        Args:
            video_info: Tuple of ``(video_path, max_frames, start_sec,
                end_sec, bbox_map)``. ``bbox_map`` maps *stringified
                absolute frame positions* to a bounding box, or None.
            sampling_fps: Target sampling rate in frames per second.

        Returns:
            The result of ``ImageTransform._transform_torch_tensor`` on the
            sampled (optionally annotated) frame tensor.
        """
        video_path, max_frames, s, e, bbox_map = video_info

        # NOTE: load_video returns None when the file is missing, in which
        # case the unpack below raises TypeError; callers should validate
        # paths upstream.
        frames, sample_pos = self.load_video(
            video_path,
            max_frames=max_frames,
            sampling_fps=sampling_fps,
            s=s,
            e=e,
        )

        if bbox_map:
            # Remap boxes keyed by absolute frame position (string keys) to
            # indices within the sampled clip.
            bbox_dict_map = {}
            for idx_pos, pos in enumerate(sample_pos):
                box = bbox_map.get(str(pos))
                if box is not None:
                    bbox_dict_map[idx_pos] = box
            if bbox_dict_map:
                frames = self.draw_bounding_boxes(frames, bbox_dict_map)

        return super()._transform_torch_tensor(frames)

    def _process_multiple_images(self, image_paths: List[str]):
        """Open each file as RGB and preprocess; see ``_process_multiple_images_pil``."""
        images = [Image.open(path).convert("RGB") for path in image_paths]
        # Delegate instead of duplicating the PIL-processing loop.
        return self._process_multiple_images_pil(images)

    def _process_multiple_images_pil(self, images: List[Image.Image]):
        """Preprocess already-open PIL images.

        Args:
            images: Non-empty list of PIL images.

        Returns:
            Tuple of the concatenated frame tensor and ``(w, h)`` of the
            *last* processed image (all images are expected to share a size).

        Raises:
            ValueError: If ``images`` is empty (``torch.cat`` would fail and
                ``(w, h)`` would be undefined).
        """
        if not images:
            raise ValueError("images must be a non-empty list")
        processed_images = []
        for image in images:
            image, (w, h) = super().__call__(image)
            processed_images.append(image)
        return torch.cat(processed_images, dim=0), (w, h)

    def load_video(self, video_path, max_frames=16, sampling_fps=1, s=None, e=None):
        """Decode a video with torchcodec and sample frames.

        Args:
            video_path (str): Path to the video file.
            max_frames (int, optional): Maximum number of frames to extract.
                Defaults to 16.
            sampling_fps (int, optional): Target sampling rate; values < 1
                are clamped to 1. Defaults to 1.
            s (float, optional): Clip start time in seconds. Defaults to None.
            e (float, optional): Clip end time in seconds. Defaults to None.

        Returns:
            Tuple of (frame tensor, list of sampled frame indices), or None
            when ``video_path`` does not exist.
        """
        if not os.path.exists(video_path):
            # Preserve the historical None return so existing callers that
            # check for it keep working, but make the failure visible.
            logger.warning("load_video: video file not found: %s", video_path)
            return None

        decoder = VideoDecoder(video_path, device="cpu")
        decoder_metadata = decoder.metadata
        fps = decoder_metadata.average_fps
        total_frames = decoder_metadata.num_frames

        start_frame = 0 if s is None else int(s * fps)
        end_frame = total_frames - 1 if e is None else int(e * fps)
        end_frame = min(end_frame, total_frames - 1)

        if start_frame > end_frame:
            start_frame, end_frame = end_frame, start_frame
        elif start_frame == end_frame:
            # Widen a degenerate clip to two frames, but never past the last
            # decodable index (start_frame + 1 alone could overrun the video).
            end_frame = min(start_frame + 1, total_frames - 1)

        # Clamp both values to >= 1: sampling_fps=0 would divide by zero and
        # sampling_fps > fps would make the stride round down to 0, which is
        # an invalid range() step.
        sample_fps = max(1, int(sampling_fps))
        t_stride = max(1, int(round(float(fps) / sample_fps)))

        all_pos = list(range(start_frame, end_frame + 1, t_stride))
        if len(all_pos) > max_frames:
            # Too many candidates: thin them uniformly down to max_frames.
            sample_idxs = self.uniform_sample(len(all_pos), max_frames)
            sample_pos = [all_pos[i] for i in sample_idxs]
        elif len(all_pos) < max_frames:
            # Too few at the target stride: resample uniformly over the whole
            # clip (ignoring the stride) to get up to max_frames frames.
            total_clip_frames = end_frame - start_frame + 1
            if total_clip_frames < max_frames:
                max_frames = total_clip_frames
            sample_idxs = self.uniform_sample(total_clip_frames, max_frames)
            sample_pos = [start_frame + idx for idx in sample_idxs]
        else:
            sample_pos = all_pos

        all_frames = decoder.get_frames_at(indices=sample_pos)
        all_frames = all_frames.data

        return all_frames, sample_pos

    def uniform_sample(self, m, n):
        """Return ``n`` indices spread evenly over ``range(m)`` (endpoints included).

        Raises:
            ValueError: If ``n > m`` (explicit raise instead of ``assert``,
                which is stripped under ``python -O``).
        """
        if n > m:
            raise ValueError(f"cannot uniformly sample {n} items from {m}")
        stride = (m - 1) / (n - 1) if n > 1 else 0
        return [int(round(i * stride)) for i in range(n)]

    def draw_bounding_boxes(self, frames, all_bboxes):
        """Draw one red bounding box on each annotated frame.

        NOTE: this method name shadows the imported torchvision
        ``draw_bounding_boxes``; inside the body the bare name still resolves
        to the module-level import, which is what is called below.

        Args:
            frames: Frame tensor of shape (N, C, H, W) — assumed uint8 as
                produced by torchcodec; TODO confirm against ImageTransform.
            all_bboxes: Mapping from frame index (within ``frames``) to a
                single box ``[x1, y1, x2, y2]``.

        Returns:
            A copy of ``frames`` with boxes drawn; the input is not mutated.
        """
        num_frames = frames.shape[0]
        frames_with_bbox = frames.clone()
        for i in range(num_frames):
            bbox = all_bboxes.get(i)
            if bbox is None:
                continue
            bbox_tensor = torch.tensor([bbox], dtype=torch.float32)
            frames_with_bbox[i] = draw_bounding_boxes(
                frames_with_bbox[i],
                boxes=bbox_tensor,
                colors=(255, 0, 0),
                width=4,
            )
        return frames_with_bbox
|
|