| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
import json
import os
import sys
from typing import Any, Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset
import decord
from decord import VideoReader
from PIL import Image

from data.video.sampler.utils import FRAME_SAMPLER_TYPES
from data.video.sampler.frames import FrameSamplerOutput
from data.transforms import VideoTransform
from data.data_utils import (
    get_flattened_position_ids_extrapolate_video,
    len2weight,
    patchify_video_with_merge,
)
from data.system_prompt_render import render_qwenvl_prompt, expand_and_index_by_token_ids_new
from data.common import generate_system_prompt
from modeling.qwen2 import Qwen2Tokenizer
from config.config_factory import ModelArguments, DataArguments, TrainingArguments
|
|
# Integer id for each validation sample family.
sample_task_map = dict(t2v=0, idip=1, edit=2, refedit=3)

# Per-token modality tags written into ``sample_modality``. Template text that is not
# the caption is tagged ``system_prompt`` (-1); caption tokens are tagged ``text``.
modality_map = dict(
    system_prompt=-1,
    text=0,
    noise=1,
    ref_source=2,
    ref_image=3,
    ref_vit=4,
)
|
|
|
|
| class ValidationDataset(Dataset): |
| def __init__( |
| self, |
| jsonl_path: str, |
| tokenizer: Qwen2Tokenizer, |
| data_args: DataArguments, |
| model_args: ModelArguments, |
| training_args: TrainingArguments, |
| new_token_ids: Dict[str, int], |
| dataset_config: None, |
| local_rank: int = 0, |
| world_size: int = 1, |
| ): |
| """ |
| 初始化验证数据集 |
| |
| Args: |
| jsonl_path: JSONL文件路径 |
| tokenizer: 分词器 |
| """ |
| self.jsonl_path = jsonl_path |
| self.tokenizer = tokenizer |
| self.new_token_ids = new_token_ids |
|
|
| |
| try: |
| full_data = self._read_jsonl() |
| except: |
| with open(jsonl_path, 'r', encoding='utf-8') as f: |
| full_data = json.load(f) |
| if isinstance(full_data, dict): |
| |
| full_data = [{"index": self.pro_index(index), "data": prompt} for index, prompt in full_data.items()] |
|
|
| if world_size > 1: |
| self.data = full_data[local_rank::world_size] |
| print(f"Rank {local_rank}/{world_size} will process {len(self.data)} samples") |
| else: |
| self.data = full_data |
|
|
| self.data_config = dataset_config |
|
|
| self.bos_token_id = self.new_token_ids["bos_token_id"] |
| self.eos_token_id = self.new_token_ids["eos_token_id"] |
| self.start_of_image = self.new_token_ids["start_of_image"] |
| self.end_of_image = self.new_token_ids["end_of_image"] |
| self.image_token_id = self.new_token_ids["image_token_id"] |
|
|
| |
| try: |
| max_duration = self.data_config.max_duration |
| except: |
| max_duration = 6.0 |
|
|
| video_frame_sampler_params = {"temporal": 4, "sample_fps": 12, "max_duration": max_duration, "assert_seconds": False, "truncate": False} |
|
|
| self.frame_sampler = FRAME_SAMPLER_TYPES["multi_clips"](**video_frame_sampler_params) |
| self.cpu_count = os.cpu_count() or 1 |
|
|
| |
| if self.data_config.resolution in ["video_192p", "image_256res"]: |
| resolution_vae = 256 |
| resolution_vit = 224 |
| elif self.data_config.resolution == "image_512res": |
| resolution_vae = 512 |
| resolution_vit = 448 |
| elif self.data_config.resolution == "image_768res": |
| resolution_vae = 768 |
| resolution_vit = 672 |
| elif self.data_config.resolution == "video_360p": |
| resolution_vae = 480 |
| resolution_vit = 476 |
| elif self.data_config.resolution == "video_480p": |
| resolution_vae = 640 |
| resolution_vit = 616 |
| else: |
| raise ValueError(f"Unknown resolution: {self.data_config.resolution}") |
|
|
| video_transform_args = { |
| "resolution": resolution_vae, |
| "mode": "bucket", |
| "divisible_crop_size": 16, |
| "stride_spatial": 16, |
| "stride_temporal": 4, |
| "aspect_ratios": ["21:9", "16:9", "4:3", "1:1", "3:4", "9:16"], |
| "mean": 0.5, |
| "std": 0.5, |
| } |
| self.transform = VideoTransform(**video_transform_args) |
|
|
| |
| vit_video_transform_args = { |
| "resolution": resolution_vit, |
| "mode": "bucket", |
| "divisible_crop_size": 28, |
| "aspect_ratios": ["21:9", "16:9", "4:3", "1:1", "3:4", "9:16"], |
| "mean": [0.48145466, 0.4578275, 0.40821073], |
| "std": [0.26862954, 0.26130258, 0.27577711], |
| } |
| self.vit_transform = VideoTransform(**vit_video_transform_args) |
|
|
| self.sample = self.set_sequence_status() |
|
|
| self.frame_condition_idx = [] |
|
|
| if hasattr(self.data_config, 'system_prompt_type'): |
| self.system_prompt_type = self.data_config.system_prompt_type |
| else: |
| self.system_prompt_type = 'SP0' |
|
|
| def pro_index(self, index: int): |
| if isinstance(index, str): |
| for x in ['.mp4', '.jpg', '.png', '.jpeg']: |
| index = index.replace(x, "") |
| return int(index) |
|
|
| def set_sequence_status(self): |
| sequence_status = dict( |
| curr=0, |
| sample_lens=[], |
| sample_type=[], |
| sample_N_target=[], |
| packed_position_ids=[], |
| nested_attention_masks=[], |
| split_lens=[], |
| attn_modes=[], |
| packed_text_ids=[], |
| packed_text_indexes=[], |
| packed_label_ids=[], |
| ce_loss_indexes=[], |
| ce_loss_weights=[], |
| vae_image_tensors=[], |
| vae_video_tensors=[], |
| packed_latent_position_ids=[], |
| vae_latent_shapes=[], |
| packed_vae_token_indexes=[], |
| packed_timesteps=[], |
| mse_loss_indexes=[], |
| packed_vit_tokens=[], |
| vit_token_seqlens=[], |
| packed_vit_position_ids=[], |
| packed_vit_token_indexes=[], |
| vit_video_grid_thw=[], |
| vae_video_grid_thw=[], |
| video_grid_thw=[], |
| vit_video_tensors=[], |
| |
| vae_video_latent=[], |
| vae_data_mode=[], |
| vit_data_mode=[], |
| key_frame_mask=[], |
| |
| sample_task=[], |
| sample_modality=[], |
| ) |
| return sequence_status |
|
|
| def _read_jsonl(self) -> List[Dict[str, Any]]: |
| """读取JSONL文件""" |
| data = [] |
| with open(self.jsonl_path, "r", encoding="utf-8") as f: |
| for line in f: |
| data.append(json.loads(line.strip())) |
| return data |
|
|
| def __len__(self) -> int: |
| return len(self.data) |
|
|
|
|
| @staticmethod |
| def _read_decord(video: VideoReader, frame_idx: List[int]) -> List[Image.Image]: |
| |
| frames_np = video.get_batch(frame_idx).asnumpy() |
| return [Image.fromarray(frame) for frame in frames_np] |
|
|
| def get_video_tensor_online(self, media_url, vision_stream, worker_id=0, element_dtype="image") -> torch.Tensor: |
| self.vision_stream = vision_stream |
| video_stream = media_url |
|
|
| if element_dtype == "image": |
| image = Image.open(video_stream) |
| if image.mode == "P": |
| image = image.convert("RGBA") |
| if image.mode == "RGBA": |
| |
| bg = Image.new("RGB", image.size, (255, 255, 255)) |
| bg.paste(image, mask=image.split()[3]) |
| image = bg |
| else: |
| image = image.convert("RGB") |
| video_frames = [image] |
| else: |
| video_reader = VideoReader(video_stream, ctx=decord.cpu(worker_id % self.cpu_count)) |
| total_frames = len(video_reader) |
|
|
| sampler_name = self.frame_sampler.__class__.__name__ |
| if sampler_name == "MultiClipsFrameSampler": |
| frames_info = { |
| "clip_indices": [(0, total_frames)], |
| "fps": 24, |
| } |
| elif sampler_name == "FixedFrameSampler": |
| frames_info = { |
| "start_frame": 0, |
| "end_frame": total_frames, |
| "total_frames": total_frames, |
| } |
| else: |
| raise ValueError(f"Not verified frame sampler type: {sampler_name}") |
|
|
| frames_sampler_output: FrameSamplerOutput = self.frame_sampler(frames_info) |
| video_frames = self._read_decord(video_reader, frames_sampler_output.indices) |
|
|
| if vision_stream == "vae_video": |
| video_tensor = self.transform(video_frames) |
| elif vision_stream == "vit_video": |
| video_tensor = self.vit_transform(video_frames) |
| if element_dtype == "image": |
| video_tensor = video_tensor.repeat(1, 2, 1, 1) |
| |
| if video_tensor.shape[1] % 2 == 1: |
| last_frame = video_tensor[:, -1:, :, :] |
| video_tensor = torch.cat([video_tensor, last_frame], dim=1) |
|
|
| else: |
| raise ValueError(f"Unknown vision_stream: {vision_stream}") |
| return video_tensor |
|
|
    def process_vit_video(self, video_tensor, curr: int, curr_rope_id: int, curr_split_len: int, curr_video_grid_thw: list, item_loss=0):
        """Append one ViT-stream image/video to the packing buffer ``self.sample``.

        The clip is patchified with ``patchify_video_with_merge``. Its tokens, token count,
        ViT grid and zero position ids are recorded. When ``data_config.text_template`` is
        falsy, this method also lays out the token span itself: ``start_of_image``, one
        ``image_token_id`` placeholder per ViT token, then ``end_of_image``. All of these
        share a single rope id and form one ``"full"``-attention split.
        NOTE(review): with a text template those index/position/split entries are
        presumably written by process_text_template, which fills packed_vit_token_indexes
        for "vit_video" spans.

        Args:
            video_tensor: Output of get_video_tensor_online. NOTE(review): indexed via
                ``size(1)``, ``size(2)`` and ``size(3)`` as T, H, W, so a channel-first
                (C, T, H, W) layout is assumed.
            curr: Next free position in the packed sequence.
            curr_rope_id: Next free rope position id.
            curr_split_len: Length already accumulated in the current split.
            curr_video_grid_thw: List that receives this clip's ViT grid (mutated in place).
            item_loss: Unused here.

        Returns:
            (self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw,
            num_video_tokens)
        """
        if not self.data_config.text_template:
            # Opening marker of the vision span.
            self.sample["packed_text_ids"].append(self.start_of_image)
            self.sample["packed_text_indexes"].append(curr)
            curr += 1
            curr_split_len += 1

        if isinstance(video_tensor, torch.Tensor):
            self.sample["vit_video_tensors"].append(video_tensor)

        vit_tokens = patchify_video_with_merge(
            video_tensor, self.data_config.vit_patch_size, self.data_config.vit_patch_size_temporal
        )
        # NOTE(review): the // 4 presumably accounts for a 2x2 spatial merge inside
        # patchify_video_with_merge; confirm against that helper.
        num_video_tokens = vit_tokens.shape[0] // 4
        t, h, w = video_tensor.size(1), video_tensor.size(2), video_tensor.size(3)

        self.sample["packed_vit_tokens"].append(vit_tokens)
        self.sample["vit_data_mode"].append("online")

        # ``t`` comes from Tensor.size() and is always an int, so this branch always runs.
        if t is not None:
            vit_video_grid_thw = [
                t // self.data_config.vit_patch_size_temporal,
                h // self.data_config.vit_patch_size,
                w // self.data_config.vit_patch_size,
            ]
            self.sample["vit_video_grid_thw"].append(vit_video_grid_thw)
            curr_video_grid_thw.append(vit_video_grid_thw)

        self.sample["vit_token_seqlens"].append(num_video_tokens)
        self.sample["packed_vit_position_ids"].append(
            torch.zeros(num_video_tokens)
        )

        if not self.data_config.text_template:
            # Placeholder positions that will hold the ViT embeddings.
            self.sample["packed_vit_token_indexes"].extend(range(curr, curr + num_video_tokens))
            curr += num_video_tokens
            curr_split_len += num_video_tokens

            self.sample["packed_text_ids"].extend([self.image_token_id] * num_video_tokens)

            # Closing marker of the vision span.
            self.sample["packed_text_ids"].append(self.end_of_image)
            self.sample["packed_text_indexes"].append(curr)
            curr += 1
            curr_split_len += 1
            # The whole span (markers + placeholders) shares one rope id.
            self.sample["packed_position_ids"].extend([curr_rope_id] * curr_split_len)
            curr_rope_id += 1

            self.sample["attn_modes"].append("full")
            self.sample["split_lens"].append(curr_split_len)

        return self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, num_video_tokens
|
|
| def process_text(self, caption: str, curr: int, curr_rope_id: int, curr_split_len: int, item_loss=0): |
| """处理文本,添加特殊token""" |
| text_ids = self.tokenizer.encode(caption) |
| shifted_text_ids = [self.bos_token_id] + text_ids |
|
|
| self.sample["packed_text_ids"].extend(shifted_text_ids) |
| self.sample["packed_text_indexes"].extend(range(curr, curr + len(shifted_text_ids))) |
|
|
| |
| if item_loss == 1: |
| loss_token_shift = 0 |
| self.sample["ce_loss_indexes"].extend(range(curr - loss_token_shift, curr + len(shifted_text_ids))) |
| self.sample["ce_loss_weights"].extend([len2weight(len(shifted_text_ids) + loss_token_shift)] * (len(shifted_text_ids) + loss_token_shift)) |
| self.sample["packed_label_ids"].extend(text_ids + [self.eos_token_id]) |
| curr += len(shifted_text_ids) |
| curr_split_len += len(shifted_text_ids) |
|
|
| |
| self.sample["packed_text_ids"].append(self.eos_token_id) |
| self.sample["packed_text_indexes"].append(curr) |
| curr += 1 |
| curr_split_len += 1 |
|
|
| |
| self.sample["attn_modes"].append("causal") |
| |
| self.sample["packed_position_ids"].extend(range(curr_rope_id, curr_rope_id + curr_split_len)) |
| curr_rope_id += curr_split_len |
|
|
| |
|
|
| self.sample["split_lens"].append(curr_split_len) |
|
|
| return self.sample, curr, curr_rope_id, curr_split_len |
|
|
|
|
    def process_vae_video(self, video_tensor, curr: int, curr_rope_id: int, curr_split_len: int, curr_video_grid_thw: list, video_sizes: list, item_loss=0):
        """Append one VAE-stream image/video to the packing buffer ``self.sample``.

        This method records:
          * the clip tensor;
          * its latent grid (t, h, w), obtained by dividing the tensor's T/H/W by
            ``data_config.vae_downsample``;
          * the latent position ids;
          * one timestep per latent token.
        With ``item_loss == 1`` the clip is a generation target: it gets one timestep
        drawn with ``np.random.randn()`` (unseeded here), and its tokens are added to
        ``mse_loss_indexes`` when no text template is used. Otherwise every token gets
        timestep ``-inf``.

        When ``data_config.text_template`` is falsy, this method also lays out the token
        span: ``start_of_image``, ``t*h*w`` ``image_token_id`` placeholders, then
        ``end_of_image``. They share one rope id and form one ``"noise"``
        (item_loss == 1) or ``"full_noise"`` split.
        NOTE(review): with a text template those entries are presumably written by
        process_text_template, which fills packed_vae_token_indexes, mse_loss_indexes and
        attn_modes for VAE spans.

        Args:
            video_tensor: Must be a torch.Tensor; ``T``, ``H``, ``W``, ``t``, ``h`` and
                ``w`` are only bound in that case. It is unpacked as ``_, T, H, W``.
            curr: Next free position in the packed sequence.
            curr_rope_id: Next free rope position id.
            curr_split_len: Length already accumulated in the current split.
            curr_video_grid_thw: List that receives this clip's grid (mutated in place).
            video_sizes: List that receives ``[T, H, W]`` (mutated in place).
            item_loss: 1 for a generation target, 0 for a clean/conditioning clip.

        Returns:
            (self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw,
            video_sizes, num_vid_tokens)
        """
        if not self.data_config.text_template:
            # Counts the start/end markers; only defined (and only used) on this path.
            num_special_tokens = 0

            self.sample["packed_text_ids"].append(self.start_of_image)
            self.sample["packed_text_indexes"].append(curr)
            curr += 1
            curr_split_len += 1
            num_special_tokens += 1

        if isinstance(video_tensor, torch.Tensor):
            self.sample["vae_video_tensors"].append(video_tensor)
            # Pixel size -> latent grid: (T - 1) // _T + 1 latent frames
            # (T = 1 -> 1, T = 1 + k * _T -> k + 1).
            _, T, H, W = video_tensor.shape
            _T, _H, _W = self.data_config.vae_downsample
            t = (T - 1) // _T + 1
            h = H // _H
            w = W // _W
            self.sample["vae_data_mode"].append("online")

        spatial_merge_size = 2
        vae_video_grid_thw = [
            t,
            h * spatial_merge_size,
            w * spatial_merge_size,
        ]

        self.sample["vae_video_grid_thw"].append(vae_video_grid_thw)
        curr_video_grid_thw.append(vae_video_grid_thw)

        self.sample["vae_latent_shapes"].append((t, h, w))

        packed_latent_position_ids = get_flattened_position_ids_extrapolate_video(t, h, w, max_latent_size=self.data_config.max_latent_size)

        self.sample["packed_latent_position_ids"].append(packed_latent_position_ids)

        num_vid_tokens = t * h * w
        if not self.data_config.text_template:
            self.sample["packed_vae_token_indexes"].extend(range(curr, curr + num_vid_tokens))

        if item_loss == 1:
            timestep = np.random.randn()

            # self.frame_condition_idx is [] after __init__, so this loop is a no-op
            # unless the attribute is changed elsewhere.
            frame_condition_idx = self.frame_condition_idx
            packed_timesteps = [timestep] * num_vid_tokens

            mse_loss_indexes = list(range(curr, curr + num_vid_tokens))
            frame_condition_indexes = []
            for idx in frame_condition_idx:
                if idx == -1:
                    idx = t - 1
                # NOTE(review): latent frame 1 is always skipped here; confirm this is
                # intended (e.g. versus skipping frame 0).
                if idx == 1:
                    continue
                # Conditioning frames are excluded from the MSE loss and get the most
                # negative finite timestep.
                frame_condition_indexes.extend(mse_loss_indexes[idx * h * w : (idx + 1) * h * w])
                packed_timesteps[idx * h * w : (idx + 1) * h * w] = [-sys.float_info.max] * (h * w)
            if frame_condition_idx:
                mse_loss_indexes = sorted(list(set(mse_loss_indexes) - set(frame_condition_indexes)))

            if not self.data_config.text_template:
                self.sample["mse_loss_indexes"].extend(mse_loss_indexes)
        else:
            timestep = float("-inf")
            packed_timesteps = [timestep] * num_vid_tokens

        self.sample["packed_timesteps"].extend(packed_timesteps)

        if not self.data_config.text_template:
            curr += num_vid_tokens
            curr_split_len += num_vid_tokens

            self.sample["packed_text_ids"].extend([self.image_token_id] * num_vid_tokens)

            # Closing marker of the vision span.
            self.sample["packed_text_ids"].append(self.end_of_image)
            self.sample["packed_text_indexes"].append(curr)
            curr += 1
            curr_split_len += 1
            num_special_tokens += 1

            if item_loss == 1:
                self.sample["attn_modes"].append("noise")
            else:
                self.sample["attn_modes"].append("full_noise")

            # The whole span (markers + latents) shares one rope id.
            self.sample["packed_position_ids"].extend([curr_rope_id] * (num_vid_tokens + num_special_tokens))
            curr_rope_id += 1

            self.sample["split_lens"].append(curr_split_len)

        video_sizes.append([T, H, W])

        return self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, video_sizes, num_vid_tokens
|
|
    def process_text_template(
        self,
        text_ids,
        spans_index,
        tgt_index,
        caption_index,
        video_types: list[str],
        curr: int,
        curr_rope_id: int,
        curr_split_len: int,
        item_loss=0,
    ):
        """Pack a pre-rendered chat template (text plus vision placeholders) into ``self.sample``.

        ``text_ids`` is appended verbatim to ``packed_text_ids``. Vision spans split it
        into alternating text runs and vision splits:
          * Each text run is one causal split, using consecutive rope ids. If it contains
            the start of the caption, it is cut into up to three causal splits around
            the caption.
          * Each vision split is ``len(span) + 2`` long: it also includes the marker
            tokens directly before and after the span. All its tokens share one rope id.

        Args:
            text_ids: Token ids of the full rendered template.
            spans_index: One list of template-relative placeholder positions per vision
                span. Only its first and last entries and its length are used.
            tgt_index: Template-relative positions of target text; used only when
                ``item_loss == 1``.
            caption_index: Template-relative positions of the caption (may be empty).
                These positions are re-tagged with modality ``text``.
            video_types: Per span, ``"vit_video"`` or a string containing
                ``"vae_video"`` plus ``"cond"`` or ``"target"``. Cond spans read
                ``self.sample_task`` (``"edit"`` / ``"idip"``).
            curr, curr_rope_id, curr_split_len: Running counters.
                NOTE(review): caption_index (template-relative) is compared against
                ``curr`` (absolute); this is only consistent when ``curr == 0``.
            item_loss: 1 to add CE loss on ``tgt_index``.

        Returns:
            (self.sample, curr, curr_rope_id, curr_split_len).
            * ``curr`` is moved past ``text_ids``.
            * NOTE(review): the returned ``curr_rope_id`` does not include the rope ids
              given to the trailing text run.
            * ``curr_split_len`` is the last vision split length, or the input value if
              there are no spans.
            ``self.sample["sample_lens"]`` is overwritten with ``len(text_ids)`` (an int).
        """
        self.sample["packed_text_ids"].extend(text_ids)
        self.sample["sample_lens"] = len(text_ids)
        # Absolute offset of the template's position 0.
        curr_split_idx = curr

        for video_id, span_index in enumerate(spans_index):
            vision_start, vision_end = curr_split_idx + span_index[0], curr_split_idx + span_index[-1]
            # Text tokens up to and including the opening marker (vision_start - 1).
            self.sample["packed_text_indexes"].extend(range(curr, vision_start))
            if (vision_start - 1) - curr != 0:
                # Text run before the opening marker.
                curr_split_len = (vision_start - 1) - curr
                self.sample["packed_position_ids"].extend(
                    range(curr_rope_id, curr_rope_id + curr_split_len)
                )
                curr_rope_id += curr_split_len
                self.sample["sample_modality"].extend([modality_map["system_prompt"]] * curr_split_len)

                # Split the run around the caption: [before, caption, after], dropping empties.
                if caption_index != [] and caption_index[0] in range(curr, curr + curr_split_len):
                    split_len_1 = caption_index[0] - curr
                    split_len_2 = len(caption_index)
                    split_len_3 = curr_split_len - split_len_1 - split_len_2

                    split_len_text = [split_len_1, split_len_2, split_len_3]
                    split_len_text = [x for x in split_len_text if x != 0]
                    self.sample["attn_modes"].extend(["causal"] * len(split_len_text))
                    self.sample["split_lens"].extend(split_len_text)
                else:
                    self.sample["attn_modes"].append("causal")
                    self.sample["split_lens"].append(curr_split_len)

            # Vision split: opening marker + placeholders + closing marker.
            curr_split_len = len(span_index) + 2
            if video_types[video_id] == "vit_video":
                self.sample["packed_vit_token_indexes"].extend(range(vision_start, vision_end + 1))
                self.sample["attn_modes"].append("full")
                self.sample["sample_modality"].extend([modality_map["ref_vit"]] * curr_split_len)
            elif "vae_video" in video_types[video_id]:
                self.sample["packed_vae_token_indexes"].extend(range(vision_start, vision_end + 1))
                if "cond" in video_types[video_id]:
                    self.sample["attn_modes"].append("full_noise")
                    # NOTE(review): any other sample_task appends no modality tags here,
                    # leaving sample_modality shorter than the sequence.
                    if self.sample_task == "edit":
                        self.sample["sample_modality"].extend([modality_map["ref_source"]] * curr_split_len)
                    elif self.sample_task == "idip":
                        self.sample["sample_modality"].extend([modality_map["ref_image"]] * curr_split_len)
                elif "target" in video_types[video_id]:
                    self.sample["mse_loss_indexes"].extend(range(vision_start, vision_end + 1))
                    self.sample["attn_modes"].append("noise")
                    self.sample["sample_modality"].extend([modality_map["noise"]] * curr_split_len)
            # NOTE(review): the nesting of this ``else`` was reconstructed; confirm whether
            # it belongs to the vit/vae dispatch (as here) or to the cond/target check.
            else:
                raise ValueError(f"video_types {video_types[video_id]} not supported")

            self.sample["packed_position_ids"].extend([curr_rope_id] * curr_split_len)
            self.sample["split_lens"].append(len(span_index) + 2)
            # Closing marker sits right after the last placeholder.
            curr = vision_end + 1
            curr_rope_id += 1
            self.sample["packed_text_indexes"].append(curr)
            curr += 1

        # Trailing text run after the last vision span (the whole template if there are none).
        len_split_last = self.sample["sample_lens"] - (curr - curr_split_idx) if spans_index != [] else len(text_ids)
        if len_split_last != 0:
            self.sample["split_lens"].append(len_split_last)
            self.sample["packed_text_indexes"].extend(range(curr, curr + len_split_last))
            self.sample["packed_position_ids"].extend(range(curr_rope_id, curr_rope_id + len_split_last))
            self.sample["attn_modes"].append("causal")
            self.sample["sample_modality"].extend([modality_map["system_prompt"]] * len_split_last)

        if item_loss == 1:
            # Labels are text_ids from tgt_index[0] onward; the loss positions are shifted by
            # -1 so each position predicts the next token.
            # NOTE(review): label and index counts only match if tgt_index runs
            # contiguously to the end of text_ids.
            packed_label_index = tgt_index
            self.sample["packed_label_ids"].extend(text_ids[packed_label_index[0] :])
            packed_label_index = np.asarray(packed_label_index, dtype=np.int64) + curr_split_idx
            ce_loss_indexes = (packed_label_index - 1).tolist()
            self.sample["ce_loss_indexes"].extend(ce_loss_indexes)
            self.sample["ce_loss_weights"].extend([len2weight(len(packed_label_index))] * (len(packed_label_index)))

        # Re-tag caption positions from "system_prompt" to "text".
        if caption_index != []:
            self.sample["sample_modality"][caption_index[0] : caption_index[-1] + 1] = [modality_map["text"]] * (caption_index[-1] - caption_index[0] + 1)

        curr_split_idx += len(text_ids)
        curr = curr_split_idx
        return self.sample, curr, curr_rope_id, curr_split_len
| def process_und_template(self, system_prompt, user_prompt, answer, vit_video_tensor): |
| """ |
| 格式: |
| <|im_start|>system |
| {system_prompt}<|im_end|> |
| <|im_start|>user |
| <|vision_start|><|video_pad|><|vision_end|>{instruction_prompt}<|im_end|> |
| <|im_start|>assistant |
| {answer}<|im_end|> |
| """ |
| curr = 0 |
| sample_lens = 0 |
| curr_rope_id = 0 |
| curr_video_grid_thw = [] |
|
|
| |
| |
| |
| |
| prompt_prefix = "<|im_start|>" + "system\n" + system_prompt + "<|im_end|>" + "\n" + "<|im_start|>" + "user\n" |
| text_ids_prompt_prefix = self.tokenizer.encode(prompt_prefix) |
| self.sample["packed_text_ids"].extend(text_ids_prompt_prefix) |
| self.sample["packed_text_indexes"].extend(range(curr, curr + len(text_ids_prompt_prefix))) |
| curr += len(text_ids_prompt_prefix) |
| split_len_prefix = len(text_ids_prompt_prefix) |
|
|
| |
| self.sample["attn_modes"].append("causal") |
| self.sample["packed_position_ids"].extend(range(curr_rope_id, curr_rope_id + split_len_prefix)) |
| self.sample["split_lens"].append(split_len_prefix) |
| curr_rope_id += split_len_prefix |
|
|
| |
| self.sample["packed_text_ids"].append(self.start_of_image) |
| self.sample["packed_text_indexes"].append(curr) |
| curr += 1 |
| split_len_vision_token = 1 |
|
|
| if isinstance(vit_video_tensor, torch.Tensor): |
| self.sample["vit_video_tensors"].append(vit_video_tensor) |
|
|
| |
| vit_tokens = patchify_video_with_merge( |
| vit_video_tensor, self.data_config.vit_patch_size, self.data_config.vit_patch_size_temporal |
| ) |
| num_video_tokens = vit_tokens.shape[0] // 4 |
| t, h, w = vit_video_tensor.size(1), vit_video_tensor.size(2), vit_video_tensor.size(3) |
|
|
| self.sample["packed_vit_tokens"].append(vit_tokens) |
| self.sample["vit_data_mode"].append("online") |
|
|
| if t is not None: |
| vit_video_grid_thw = [ |
| t // self.data_config.vit_patch_size_temporal, |
| h // self.data_config.vit_patch_size, |
| w // self.data_config.vit_patch_size, |
| ] |
| self.sample["vit_video_grid_thw"].append(vit_video_grid_thw) |
| curr_video_grid_thw.append(vit_video_grid_thw) |
|
|
| self.sample["vit_token_seqlens"].append(num_video_tokens) |
| self.sample["packed_vit_position_ids"].append( |
| torch.zeros(num_video_tokens) |
| ) |
|
|
| self.sample["packed_vit_token_indexes"].extend(range(curr, curr + num_video_tokens)) |
| curr += num_video_tokens |
| split_len_vision_token += num_video_tokens |
|
|
| |
| self.sample["packed_text_ids"].extend([self.image_token_id] * num_video_tokens) |
|
|
| |
| self.sample["packed_text_ids"].append(self.end_of_image) |
| self.sample["packed_text_indexes"].append(curr) |
| curr += 1 |
| split_len_vision_token += 1 |
|
|
| |
| self.sample["attn_modes"].append("full") |
| self.sample["packed_position_ids"].extend([curr_rope_id] * split_len_vision_token) |
| self.sample["split_lens"].append(split_len_vision_token) |
| curr_rope_id += 1 |
|
|
| |
| |
| |
| prompt_postfix = user_prompt + "<|im_end|>" + "\n" + "<|im_start|>" + "assistant" |
| text_ids_prompt_postfix = self.tokenizer.encode(prompt_postfix) |
| self.sample["packed_text_ids"].extend(text_ids_prompt_postfix) |
| self.sample["packed_text_indexes"].extend(range(curr, curr + len(text_ids_prompt_postfix))) |
| curr += len(text_ids_prompt_postfix) |
| split_len_postfix = len(text_ids_prompt_postfix) |
|
|
| |
| self.sample["attn_modes"].append("causal") |
| self.sample["packed_position_ids"].extend(range(curr_rope_id, curr_rope_id + split_len_postfix)) |
| self.sample["split_lens"].append(split_len_postfix) |
| curr_rope_id += split_len_postfix |
|
|
| |
| answer = "\n" + answer |
| answer_ids = self.tokenizer.encode(answer) |
| shifted_text_ids_answer = answer_ids + [self.eos_token_id] |
| self.sample["packed_text_ids"].extend(shifted_text_ids_answer) |
| self.sample["packed_text_indexes"].extend(range(curr, curr + len(shifted_text_ids_answer))) |
|
|
| |
| self.sample["ce_loss_indexes"].extend(range(curr, curr + len(shifted_text_ids_answer))) |
| self.sample["ce_loss_weights"].extend([len2weight(len(shifted_text_ids_answer))] * (len(shifted_text_ids_answer))) |
| self.sample["packed_label_ids"].extend(shifted_text_ids_answer) |
|
|
| curr += len(shifted_text_ids_answer) |
| split_len_answer = len(shifted_text_ids_answer) |
|
|
| |
| self.sample["attn_modes"].append("causal") |
| self.sample["packed_position_ids"].extend(range(curr_rope_id, curr_rope_id + split_len_answer)) |
| self.sample["split_lens"].append(split_len_answer) |
| curr_rope_id += split_len_answer |
|
|
| sample_lens = len(self.sample["packed_text_ids"]) |
|
|
| return sample_lens, curr_video_grid_thw |
|
|
| def _finalize_sample(self, sample_lens, curr_video_grid_thw, sample_type, sample=None, additional_fields=None, video_sizes=None): |
| """通用 sample 结尾处理,减少代码重复""" |
| self.sample["sample_lens"] = [sample_lens] |
| self.sample["video_grid_thw"] = torch.tensor([curr_video_grid_thw]) |
| self.sample["packed_text_ids"] = torch.tensor(self.sample["packed_text_ids"]) |
| self.sample["packed_text_indexes"] = torch.tensor(self.sample["packed_text_indexes"]) |
|
|
| self.sample["packed_vae_token_indexes"] = torch.tensor(self.sample["packed_vae_token_indexes"]) |
| self.sample["packed_position_ids"] = torch.tensor(self.sample["packed_position_ids"]) |
| self.sample["vae_video_grid_thw"] = torch.tensor(self.sample["vae_video_grid_thw"]) |
|
|
| self.sample["vit_video_grid_thw"] = torch.tensor(self.sample["vit_video_grid_thw"]) |
| self.sample["packed_vit_token_indexes"] = torch.tensor(self.sample["packed_vit_token_indexes"]) |
|
|
| self.sample["sample_N_target"] = torch.tensor([[1]]) |
| self.sample["sample_type"] = [sample_type] |
| self.sample["padded_videos"] = self.sample["vae_video_tensors"] |
|
|
| if "ce_loss_indexes" in self.sample and len(self.sample["ce_loss_indexes"]) > 0: |
| self.sample["ce_loss_indexes"] = torch.tensor(self.sample["ce_loss_indexes"]) |
| |
| self.sample["mse_loss_indexes"] = torch.tensor(self.sample["mse_loss_indexes"]) |
| if video_sizes is not None: |
| self.sample["video_sizes"] = torch.tensor(video_sizes) |
| elif "video_sizes" in self.sample: |
| self.sample["video_sizes"] = torch.tensor(self.sample["video_sizes"]) |
| if "sample_modality" in self.sample and len(self.sample["sample_modality"]) > 0: |
| self.sample["sample_modality"] = torch.tensor(self.sample["sample_modality"]) |
|
|
| if sample is not None: |
| for key in ["index", "category", "question", "gt"]: |
| if key in sample: |
| self.sample[key] = sample[key] |
|
|
| if additional_fields is not None: |
| for key, value in additional_fields.items(): |
| self.sample[key] = value |
|
|
| return self.sample |
|
|
| def ti2t_sample(self, idx: int) -> Dict[str, Any]: |
| """ |
| 获取单个样本 |
| 默认system_prompt和user_prompt中均不包含sos和eos token |
| 格式: |
| <|im_start|>system |
| {system_prompt}<|im_end|> |
| <|im_start|>user |
| <|vision_start|><|video_pad|><|vision_end|>{instruction_prompt}<|im_end|> |
| <|im_start|>assistant |
| {answer}<|im_end|> |
| """ |
| self.sample = self.set_sequence_status() |
| sample = self.data[idx] |
|
|
| system_prompt = sample["system_prompt"] |
| user_prompt = sample["user_prompt"] |
| answer = sample["gt"] |
| image_path = sample["image_path"] |
| vit_image_tensor = self.get_video_tensor_online(image_path, vision_stream="vit_video", element_dtype="image") |
|
|
| sample_lens, curr_video_grid_thw = self.process_und_template( |
| system_prompt=system_prompt, |
| user_prompt=user_prompt, |
| answer=answer, |
| vit_video_tensor=vit_image_tensor, |
| ) |
|
|
| self.sample["system_prompt"] = system_prompt |
| self.sample["user_prompt"] = user_prompt |
| self.sample["image_path"] = image_path |
| self.sample["instruction"] = user_prompt |
|
|
| return self._finalize_sample( |
| sample_lens, curr_video_grid_thw, |
| sample_type="und", |
| sample=sample |
| ) |
|
|
| def t2v_sample(self, idx: int) -> Dict[str, Any]: |
| """获取单个样本""" |
| _T, _H, _W = self.data_config.vae_downsample |
| if self.data_config.task == "t2i": |
| t = 1 |
| t_ = 1 |
| element_dtype = 'image' |
| else: |
| t = (self.data_config.num_frames - 1) // _T + 1 |
| t_ = self.data_config.num_frames |
| element_dtype = 'video' |
|
|
| self.sample = self.set_sequence_status() |
| packed_text_indexes, packed_position_ids, sample_modality = [], [], [] |
| sample = self.data[idx] |
| if "prompt_en" in sample.keys(): |
| user_prompt = "".join(sample["prompt_en"][0]) |
| |
| else: |
| user_prompt = sample["data"] |
|
|
| if self.data_config.text_template: |
| caption_instruction = generate_system_prompt(system_prompt_type=self.data_config.task, vision_type=element_dtype) |
|
|
| text_template_user, text_template_assistant, vit_num_tokens, video_types = [], [], [], [] |
| if self.system_prompt_type == 'SP2': |
| user_prompt = caption_instruction + " " + user_prompt |
| caption_instruction = "You are a helpful assistant. " |
| elif self.system_prompt_type == 'SP1': |
| |
| caption_instruction = "You are a helpful assistant. " + caption_instruction |
|
|
| text_template_user.append({"type": "text", "text": user_prompt}) |
| else: |
| |
| text_ids = self.tokenizer.encode(user_prompt) |
| text_ids = [self.new_token_ids["bos_token_id"]] + text_ids + [self.new_token_ids["eos_token_id"]] |
| text_split_len = len(text_ids) |
| packed_text_indexes.extend(range(0, text_split_len)) |
| packed_position_ids.extend(range(0, text_split_len)) |
| sample_modality.extend([modality_map['text']] * text_split_len) |
|
|
| |
|
|
| h = self.data_config.H // _H |
| w = self.data_config.W // _W |
| spatial_merge_size = 2 |
| |
| num_vid_tokens = t * h * w |
|
|
| if self.data_config.text_template: |
| text_template_assistant.append({"type":element_dtype}) |
| else: |
| text_ids.append(self.new_token_ids["start_of_image"]) |
| packed_text_indexes.append(text_split_len) |
| packed_vae_token_indexes = torch.tensor(range(len(text_ids), len(text_ids) + num_vid_tokens)) |
| text_ids.extend([self.image_token_id] * num_vid_tokens) |
| text_ids.append(self.new_token_ids["end_of_image"]) |
| packed_text_indexes.append(len(text_ids) - 1) |
| video_split_len = num_vid_tokens + 2 |
| packed_position_ids.extend([text_split_len] * video_split_len) |
| sample_modality.extend([modality_map['noise']] * video_split_len) |
|
|
| if self.data_config.text_template: |
| all_token_id, spans_index, tgt_index, search_index = self.render_template(caption_instruction, text_template_assistant, text_template_user, [num_vid_tokens], search_text=user_prompt) |
|
|
| |
| self.sample, curr, curr_rope_id, curr_split_len = self.process_text_template( |
| all_token_id, |
| spans_index, |
| tgt_index, |
| search_index, |
| video_types=['target_vae_video'], |
| curr=0, |
| curr_rope_id=0, |
| curr_split_len=0, |
| item_loss=0, |
| ) |
|
|
| |
| return { |
| "packed_text_ids": torch.tensor(text_ids) if not self.data_config.text_template else torch.tensor(self.sample["packed_text_ids"]), |
| "packed_text_indexes": torch.tensor(packed_text_indexes) if not self.data_config.text_template else torch.tensor(self.sample["packed_text_indexes"]), |
| "packed_vae_token_indexes": packed_vae_token_indexes if not self.data_config.text_template else torch.tensor(self.sample["packed_vae_token_indexes"]), |
| "vae_video_grid_thw": torch.tensor([[t, h * spatial_merge_size, w * spatial_merge_size]]), |
| "video_grid_thw": torch.tensor([[[t, h * spatial_merge_size, w * spatial_merge_size]]]), |
| "sample_N_target": torch.tensor([[1]]), |
| "split_lens": [text_split_len, video_split_len] if not self.data_config.text_template else self.sample["split_lens"], |
| "attn_modes": ["causal", "noise"] if not self.data_config.text_template else self.sample["attn_modes"], |
| "sample_lens": [text_split_len + video_split_len] if not self.data_config.text_template else [self.sample["sample_lens"]], |
| "val_sample_type": ["gen"], |
| "padded_latent": None, |
| "mse_loss_indexes": packed_vae_token_indexes if not self.data_config.text_template else torch.tensor(self.sample["mse_loss_indexes"]), |
| "video_sizes": torch.tensor([[t_, self.data_config.H, self.data_config.W]]), |
| "packed_position_ids": torch.tensor(packed_position_ids) if not self.data_config.text_template else torch.tensor(self.sample["packed_position_ids"]), |
| "caption": user_prompt, |
| "sample_type": ["gen"], |
| "index": sample["index"], |
| "caption_cn": user_prompt, |
| "original_prompt_en": sample["original_prompt_en"] if "original_prompt_en" in sample.keys() else user_prompt, |
| "sample_task": torch.zeros(text_split_len + video_split_len) if not self.data_config.text_template else torch.zeros(self.sample["sample_lens"]), |
| "sample_modality": torch.tensor(sample_modality) if not self.data_config.text_template else torch.tensor(self.sample["sample_modality"]), |
| "additional_info": sample["additional_info"] if "additional_info" in sample.keys() else None, |
| } |
|
|
| def tv2v_sample(self, idx: int) -> Dict[str, Any]: |
| """获取单个样本 - 使用 tiv2v_sample 的通用 interleave 格式""" |
| sample = self.data[idx] |
| user_prompt = "Create a 2D animation based on the provided image of a maze. The blue star slides smoothly along the white path, stopping perfectly on the red flag and then acquiring a trophy. The blue star never slides or crosses into the black segments of the maze. The camera is a static, top-down view showing the entire maze." |
| |
| |
| sample["data"] = { |
| "interleave_array": [user_prompt, sample["image_path"], sample["image_path"], sample["video_path"]], |
| "element_dtype_array": ["text", "image", "image", "video"], |
| "istarget_in_interleave": [0, 0, 0, 1] |
| } |
| |
| self.sample_task = 'edit' |
| result = self.tiv2v_sample(idx) |
| |
| |
| result["caption"] = user_prompt |
| result["caption_cn"] = user_prompt |
| |
| return result |
|
|
| def tiv2v_sample(self, idx: int) -> Dict[str, Any]: |
| """获取单个样本""" |
| sample_modality, text_template_user, text_template_assistant, vit_num_tokens, video_types = [], [], [], [], [] |
| self.sample = self.set_sequence_status() |
| sample_lens = 0 |
| sample = self.data[idx] |
|
|
| index = sample["index"] |
| data_sample = sample["data"] |
| additional_info = sample["data"]["additional_info"] if "additional_info" in sample["data"] else [] |
|
|
| interleave_array, element_dtype_array, istarget_in_interleave = data_sample["interleave_array"], data_sample["element_dtype_array"], data_sample["istarget_in_interleave"] |
|
|
| curr, curr_rope_id, curr_split_len, curr_video_grid_thw, video_sizes, caption_all = 0, 0, 0, [], [], '' |
| for element, element_dtype, is_target in zip(interleave_array, element_dtype_array, istarget_in_interleave): |
| if element_dtype == "text": |
| |
| caption_all += element |
| if self.data_config.text_template: |
| text_template_user.append({"type": "text", "text": element}) |
| search_text = element |
| else: |
| self.sample, curr, curr_rope_id, curr_split_len = self.process_text(element, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, item_loss=is_target) |
| sample_lens += curr_split_len |
| sample_modality.extend([modality_map['text']] * curr_split_len) |
| elif element_dtype in ["image", "video"]: |
| if is_target == 0: |
| vit_image_tensor = self.get_video_tensor_online(element, vision_stream="vit_video", element_dtype=element_dtype) |
| self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, num_tokens_ = self.process_vit_video( |
| vit_image_tensor, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, curr_video_grid_thw=curr_video_grid_thw, item_loss=0 |
| ) |
| if self.data_config.text_template: |
| text_template_user.append({"type": element_dtype}) |
| vit_num_tokens.append(num_tokens_) |
| video_types.append("vit_video") |
| else: |
| sample_lens += curr_split_len |
| sample_modality.extend([modality_map['ref_vit']] * curr_split_len) |
|
|
| |
| vae_image_tensor = self.get_video_tensor_online(element, vision_stream="vae_video", element_dtype=element_dtype) |
| self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, video_sizes, num_tokens_ = self.process_vae_video( |
| vae_image_tensor, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, curr_video_grid_thw=curr_video_grid_thw, video_sizes=video_sizes, item_loss=is_target |
| ) |
| if self.data_config.text_template: |
| vit_num_tokens.append(num_tokens_) |
| if is_target == 0: |
| text_template_user.append({"type": element_dtype}) |
| video_types.append("cond_vae_video") |
| else: |
| text_template_assistant.append({"type": element_dtype}) |
| video_types.append("target_vae_video") |
| else: |
| sample_lens += curr_split_len |
| if is_target == 0: |
| sample_modality.extend([modality_map[f'ref_{element_dtype}']] * curr_split_len) |
| else: |
| sample_modality.extend([modality_map[f'noise']] * curr_split_len) |
|
|
| if self.data_config.text_template: |
| if text_template_user[0]['type']=='text': |
| text_template_user = text_template_user[1:] + text_template_user[:1] |
| caption_instruction = generate_system_prompt(system_prompt_type=self.data_config.task, vision_type=element_dtype) |
| all_token_id, spans_index, tgt_index, search_index = self.render_template(caption_instruction, text_template_assistant, text_template_user, vit_num_tokens, search_text=search_text) |
| |
| self.sample, curr, curr_rope_id, curr_split_len = self.process_text_template( |
| all_token_id, |
| spans_index, |
| tgt_index, |
| search_index, |
| video_types=video_types, |
| curr=0, |
| curr_rope_id=0, |
| curr_split_len=0, |
| item_loss=0, |
| ) |
| sample_lens = len(all_token_id) |
| sample_modality = self.sample["sample_modality"] |
|
|
|
|
| additional_fields = { |
| "caption": caption_all, |
| "caption_cn": caption_all, |
| "index": sample["index"], |
| "additional_info": additional_info |
| } |
|
|
| if self.sample_task == 'edit': |
| self.sample["sample_task"] = torch.ones(sample_lens) * sample_task_map['edit'] |
| elif self.sample_task == 'idip': |
| self.sample["sample_task"] = torch.ones(sample_lens) * sample_task_map['idip'] |
|
|
| return self._finalize_sample( |
| sample_lens, curr_video_grid_thw, |
| sample_type="gen", |
| sample=sample, |
| additional_fields=additional_fields, |
| video_sizes=video_sizes |
| ) |
|
|
| def render_template(self, instruction, text_template_assistant, text_template_user, vit_num_tokens, search_text=""): |
| |
| |
|
|
| |
| |
| |
|
|
| messages = [ |
| { |
| "role": "user", |
| "content": text_template_user, |
| }, |
| { |
| "role": "assistant", |
| "content": text_template_assistant, |
| }, |
| ] |
| caption_all = render_qwenvl_prompt(messages, default_system=instruction, include_assistant_content=True) |
|
|
| all_token_id, spans_index, tgt_index, search_index = expand_and_index_by_token_ids_new( |
| rendered_text=caption_all.strip(), tokens=vit_num_tokens, target_text=f"assistant\n", tokenizer=self.tokenizer, search_text=search_text |
| ) |
| assert len(all_token_id[tgt_index[0] :]) == len(tgt_index) |
| return all_token_id, spans_index, tgt_index, search_index |
|
|
| def x2t_sample(self, idx: int) -> Dict[str, Any]: |
| """获取单个样本""" |
| sample_modality = [] |
| self.sample = self.set_sequence_status() |
| sample_lens = 0 |
| sample = self.data[idx] |
| index = sample["index"] |
| data_sample = sample["data"] |
|
|
| interleave_array, element_dtype_array, istarget_in_interleave = data_sample["interleave_array"], data_sample["element_dtype_array"], data_sample["istarget_in_interleave"] |
|
|
| curr, curr_rope_id, curr_split_len, curr_video_grid_thw, video_sizes, caption_all = 0, 0, 0, [], [], "" |
| if self.data_config.text_template: |
| text_template_user, text_template_assistant, vit_num_tokens, video_types = [], [], [], [] |
| for element, element_dtype, is_target in zip(interleave_array, element_dtype_array, istarget_in_interleave): |
| if element_dtype == "text": |
| |
| if is_target == 1: |
| if self.data_config.text_template: |
| if isinstance(element, str): |
| caption_a = element |
| caption_i = generate_system_prompt(system_prompt_type="caption", vision_type=element_dtype_array[0]) |
| caption_q = "" |
| element = [caption_i, caption_q, caption_a] |
|
|
| |
| caption_i, caption_q, caption_a = element[0], element[1], element[2] |
| if self.system_prompt_type == 'SP2': |
| caption_q = caption_i + " " + caption_q |
| caption_i = "You are a helpful assistant. " |
| elif self.system_prompt_type == 'SP1': |
| |
| caption_i = "You are a helpful assistant. " + caption_i |
| element = [caption_i, caption_q, caption_a] |
|
|
| print('element',element) |
| |
|
|
| caption_i, caption_q, caption_a = element[0], element[1], element[2] |
|
|
| text_template_assistant.append({"type": "text", "text": caption_a}) |
| if caption_q != "": |
| text_template_user.append({"type": "text", "text": caption_q}) |
|
|
| all_token_id, spans_index, tgt_index, search_index = self.render_template(caption_i, text_template_assistant, text_template_user, vit_num_tokens) |
| self.sample, curr, curr_rope_id, curr_split_len = self.process_text_template( |
| all_token_id, |
| spans_index, |
| tgt_index, |
| search_index, |
| video_types, |
| curr=curr, |
| curr_rope_id=curr_rope_id, |
| curr_split_len=0, |
| item_loss=is_target, |
| ) |
| sample_lens += curr_split_len |
|
|
| caption_all += "\n".join(element) |
| caption_answer = element[-1] |
| else: |
| if isinstance(element, list): |
| element = element[-1] |
| self.sample, curr, curr_rope_id, curr_split_len = self.process_text( |
| element, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, item_loss=is_target |
| ) |
| sample_lens += curr_split_len |
| sample_modality.extend([modality_map["text"]] * curr_split_len) |
| caption_all += element |
| caption_answer = element |
|
|
| elif element_dtype in ["image", "video"]: |
|
|
| vit_image_tensor = self.get_video_tensor_online(element, vision_stream="vit_video", element_dtype=element_dtype) |
| self.sample, curr, curr_rope_id, curr_split_len, curr_video_grid_thw, num_tokens_ = self.process_vit_video( |
| vit_image_tensor, curr=curr, curr_rope_id=curr_rope_id, curr_split_len=0, curr_video_grid_thw=curr_video_grid_thw, item_loss=0 |
| ) |
| sample_lens += curr_split_len |
| sample_modality.extend([modality_map["ref_vit"]] * curr_split_len) |
| index_video_path_name = element.split("/")[-1] |
|
|
| if self.data_config.text_template: |
| text_template_user.append({"type": element_dtype}) |
| vit_num_tokens.append(num_tokens_) |
| video_types.append("vit_video") |
|
|
| if self.sample["sample_lens"] != []: |
| sample_lens = self.sample["sample_lens"] |
|
|
| if self.sample["sample_modality"] != []: |
| sample_modality = self.sample["sample_modality"] |
| self.sample["sample_modality"] = sample_modality |
| self.sample["sample_task"] = torch.ones(self.sample["sample_lens"]) * sample_task_map["t2v"] |
|
|
| additional_fields = { |
| "caption": caption_all, |
| "caption_cn": caption_all, |
| "caption_answer": caption_answer, |
| "index_item": index, |
| "index": index_video_path_name, |
| "additional_information": data_sample["additional_information"] if "additional_information" in data_sample.keys() else {}, |
| "visual_path": data_sample["interleave_array"][0], |
| "question": data_sample["interleave_array"][1][1] if isinstance(data_sample["interleave_array"][1], list) and len(data_sample["interleave_array"][1]) > 1 else None, |
| "answer": data_sample["interleave_array"][1][2] if isinstance(data_sample["interleave_array"][1], list) and len(data_sample["interleave_array"][1]) > 2 else None |
| } |
|
|
| return self._finalize_sample( |
| sample_lens, curr_video_grid_thw, |
| sample_type="und", |
| additional_fields=additional_fields |
| ) |
|
|
| def __getitem__(self, idx: int) -> Dict[str, Any]: |
| if self.data_config.task == "tv2v": |
| return self.tv2v_sample(idx) |
| elif self.data_config.task in ["t2i","t2v"]: |
| return self.t2v_sample(idx) |
| elif self.data_config.task == "ti2t": |
| return self.ti2t_sample(idx) |
| elif "tiv2v" in self.data_config.task: |
| if 'edit' in self.data_config.task: |
| self.sample_task = 'edit' |
| elif 'idip' in self.data_config.task: |
| self.sample_task = 'idip' |
| return self.tiv2v_sample(idx) |
| elif self.data_config.task == "video_edit": |
| self.sample_task = 'edit' |
| return self.tiv2v_sample(idx) |
| elif self.data_config.task == "video_idip" or self.data_config.task == "video_idip_multiref": |
| self.sample_task = 'idip' |
| return self.tiv2v_sample(idx) |
| elif self.data_config.task == "image_edit": |
| self.sample_task = 'edit' |
| return self.tiv2v_sample(idx) |
| elif self.data_config.task == "image_idip": |
| self.sample_task = 'idip' |
| return self.tiv2v_sample(idx) |
| elif self.data_config.task in ["x2t", "x2t_image", "x2t_video"]: |
| return self.x2t_sample(idx) |
| else: |
| raise ValueError(f"Unknown task: {self.data_config.task}") |
|
|