| import os |
| import torch |
| import argparse |
| from PIL import Image |
| from tqdm import tqdm |
| from transformers import AutoModel, AutoProcessor |
| import torch.multiprocessing as mp |
| from torch.utils.data import Dataset, DataLoader |
| import glob |
|
|
| |
# Local checkpoint path (or HF hub id) of the SigLIP model used for image embeddings.
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
# Frames encoded per forward pass on each GPU — tune down if you hit CUDA OOM.
BATCH_SIZE = 1024
|
|
def parse_arguments():
    """Parse command-line arguments for the embedding precompute step."""
    arg_parser = argparse.ArgumentParser(
        description="步骤 1: 使用 SigLIP (多GPU) 预计算所有视频帧的嵌入."
    )
    # Base directory containing one sub-directory of extracted frames per video.
    arg_parser.add_argument(
        "--frames-path",
        "-fp",
        required=True,
        type=str,
        help="包含所有视频帧文件夹的基础目录的绝对路径。",
    )
    # Destination directory for the per-video embedding .pt files.
    arg_parser.add_argument(
        "--output-dir",
        "-o",
        required=True,
        type=str,
        help="用于保存嵌入.pt文件的输出目录路径。",
    )
    return arg_parser.parse_args()
|
|
class FrameDataset(Dataset):
    """PyTorch Dataset that lazily loads video frames from a list of file paths."""

    def __init__(self, frame_paths):
        # Only the paths are stored; images are decoded on demand in __getitem__.
        self.frame_paths = frame_paths

    def __len__(self):
        return len(self.frame_paths)

    def __getitem__(self, idx):
        # Unreadable/corrupt frames yield None instead of raising, so a single
        # bad file cannot abort the whole video; collate_fn filters them out.
        try:
            return Image.open(self.frame_paths[idx]).convert("RGB")
        except Exception:
            return None
|
|
def collate_fn(batch):
    """Filter failed (None) frame loads out of a batch.

    Returns the surviving items as a list, or None when every item failed,
    which the consumer treats as "skip this batch".
    """
    survivors = [item for item in batch if item is not None]
    return survivors if survivors else None
|
|
def process_video_chunk(args_tuple):
    """Worker: compute SigLIP embeddings for a chunk of videos on one GPU.

    Designed to run inside a spawned multiprocessing worker; each worker loads
    its own model copy onto its assigned device and writes one ``.pt`` file
    per video into ``output_dir``.

    Args:
        args_tuple: ``(video_dirs_chunk, frames_base_path, gpu_id, output_dir)``
            video_dirs_chunk: list of per-video frame directories to process.
            frames_base_path: base frames directory (unused here; kept so the
                tuple layout produced by ``main`` stays unchanged).
            gpu_id: CUDA device index this worker is pinned to.
            output_dir: directory receiving ``<video_name>.pt`` files, each a
                dict with 'filenames' (list[str]) and 'embeddings' (CPU tensor).
    """
    video_dirs_chunk, frames_base_path, gpu_id, output_dir = args_tuple
    device = f"cuda:{gpu_id}"

    # Each spawned worker instantiates its own model/processor on its GPU.
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

    progress_bar = tqdm(video_dirs_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")

    for video_dir in progress_bar:
        video_name = os.path.basename(video_dir)
        output_path = os.path.join(output_dir, f"{video_name}.pt")

        # Resumability: a previous (possibly interrupted) run's outputs are kept.
        if os.path.exists(output_path):
            progress_bar.write(f"Skipping {video_name}, embeddings already exist.")
            continue

        frame_files = [f for f in os.listdir(video_dir) if f.endswith(".jpg")]
        if not frame_files:
            continue
        # Numeric sort; assumes names like "frame_<idx>.jpg" — TODO confirm
        # against the frame-extraction step (other shapes raise and are caught below).
        frame_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
        frame_paths = [os.path.join(video_dir, f) for f in frame_files]

        try:
            with torch.no_grad():
                dataset = FrameDataset(frame_paths)
                loader = DataLoader(
                    dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
                    pin_memory=True, collate_fn=collate_fn
                )

                all_frame_embeddings = []
                for image_batch in loader:
                    # collate_fn returns None when every frame in the batch failed.
                    if image_batch is None:
                        continue

                    image_inputs = processor(images=image_batch, return_tensors="pt").to(device)
                    frame_embeddings = model.get_image_features(**image_inputs)
                    # Move each batch off the GPU immediately so peak GPU memory
                    # is bounded by one batch instead of the whole video.
                    all_frame_embeddings.append(frame_embeddings.cpu())

                if not all_frame_embeddings:
                    continue

                # NOTE(review): if collate_fn dropped any frames, 'filenames'
                # has more entries than 'embeddings' has rows — verify whether
                # downstream consumers rely on a 1:1 alignment.
                data_to_save = {
                    'filenames': frame_files,
                    'embeddings': torch.cat(all_frame_embeddings, dim=0)
                }
                torch.save(data_to_save, output_path)

        except Exception as e:
            # Best-effort per video: log and continue so one bad video does not
            # take down the whole worker's chunk.
            progress_bar.write(f"Error on GPU-{gpu_id} for video '{video_name}': {e}")
|
|
def main():
    """Coordinate multi-GPU embedding precompute.

    Discovers per-video frame directories under ``--frames-path``, splits them
    into one contiguous chunk per available GPU, and fans the chunks out to a
    process pool where each worker runs :func:`process_video_chunk` on its own
    device.

    Raises:
        SystemExit: with code 1 when no CUDA GPU is available.
    """
    args = parse_arguments()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误: 未找到启用CUDA的GPU。正在退出。")
        # raise SystemExit instead of the site-module exit() builtin, which is
        # not guaranteed to exist outside interactive setups.
        raise SystemExit(1)

    print(f"找到 {num_gpus} 个GPU。开始并行处理...")

    os.makedirs(args.output_dir, exist_ok=True)

    # Each immediate sub-directory of frames-path is one video's frame folder.
    video_dirs = [d for d in glob.glob(os.path.join(args.frames_path, '*')) if os.path.isdir(d)]

    if not video_dirs:
        print(f"错误: 在 {args.frames_path} 中未找到视频目录。")
        return

    # Ceiling division so every video lands in exactly one chunk; there may be
    # fewer chunks than GPUs when there are few videos.
    chunk_size = (len(video_dirs) + num_gpus - 1) // num_gpus
    video_chunks = [video_dirs[i:i + chunk_size] for i in range(0, len(video_dirs), chunk_size)]

    # Tuple layout must match process_video_chunk's unpacking.
    process_args = [(video_chunks[i], args.frames_path, i, args.output_dir) for i in range(len(video_chunks))]

    with mp.Pool(processes=num_gpus) as pool:
        pool.map(process_video_chunk, process_args)

    print("\n所有视频帧嵌入已计算并保存。")
|
|
if __name__ == "__main__":
    # 'spawn' is required here: CUDA contexts cannot be safely re-initialized
    # in forked children, and each pool worker initializes its own device.
    mp.set_start_method('spawn', force=True)
    main()
|
|