| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import warnings |
| warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*", category=UserWarning) |
| warnings.filterwarnings("ignore", category=FutureWarning, module="diffusers.models.transformers.transformer_2d") |
| import os |
| import time |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" |
| os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" |
|
|
| import os.path as osp |
| from copy import deepcopy |
| import json |
| from typing import Tuple, cast, Optional |
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader |
| from transformers import HfArgumentParser, set_seed |
| from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig |
| from safetensors.torch import load_file |
|
|
| from data.dataset_base import DataConfig, simple_custom_collate |
| from data.data_utils import add_special_tokens |
| from modeling.vae.wan.model import WanVideoVAE |
| from modeling.lance import LanceConfig, Lance, Qwen2ForCausalLM |
| from modeling.qwen2 import Qwen2Tokenizer |
| from modeling.qwen2.modeling_qwen2 import Qwen2Config |
| from modeling.vit.qwen2_5_vl_vit import Qwen2_5_VisionTransformerPretrainedModel |
| from common.utils.misc import tuple_mul, AutoEncoderParams |
| from common.utils.distributed import get_global_rank |
| from common.utils.logging import get_logger |
| from common.val.utils import make_padded_latent, decode_video_tensor |
| from data.datasets_custom import ValidationDataset |
| from config.config_factory import ( |
| ModelArguments, |
| DataArguments, |
| InferenceArguments, |
| get_model_path, |
| ) |
|
|
| from tqdm import trange |
|
|
|
|
| |
| MAX_GENERATION_LENGTH = 256 |
| PROMPT_JSON_FILENAME = "prompt.json" |
| RESULT_JSON_FILENAME = "result.json" |
| INTERNAL_VALIDATION_MAX_SAMPLES = 100000 |
| TASK_T2V = "t2v" |
| TASK_T2I = "t2i" |
| TASK_X2T_IMAGE = "x2t_image" |
| TASK_X2T_VIDEO = "x2t_video" |
| TASK_IMAGE_EDIT = "image_edit" |
| TASK_VIDEO_EDIT = "video_edit" |
| GENERATION_TASKS = { |
| TASK_T2V, |
| TASK_T2I, |
| TASK_IMAGE_EDIT, |
| TASK_VIDEO_EDIT, |
| } |
| UNDERSTANDING_TASKS = { |
| TASK_X2T_IMAGE, |
| TASK_X2T_VIDEO, |
| } |
| TASK_DEFAULT_CONFIGS = { |
| TASK_T2I: { |
| "model_family": "image", |
| "example_json": "config/examples/t2i_example.json", |
| "save_path_prefix": "results/t2i_sample", |
| }, |
| TASK_T2V: { |
| "model_family": "video", |
| "example_json": "config/examples/t2v_example.json", |
| "save_path_prefix": "results/t2v_sample", |
| }, |
| TASK_IMAGE_EDIT: { |
| "model_family": "image", |
| "example_json": "config/examples/image_edit_example.json", |
| "save_path_prefix": "results/image_edit_sample", |
| }, |
| TASK_VIDEO_EDIT: { |
| "model_family": "video", |
| "example_json": "config/examples/video_edit_example.json", |
| "save_path_prefix": "results/video_edit_sample", |
| }, |
| TASK_X2T_IMAGE: { |
| "model_family": "image", |
| "example_json": "config/examples/x2t_image_example.json", |
| "save_path_prefix": "results/x2t_image_sample", |
| }, |
| TASK_X2T_VIDEO: { |
| "model_family": "video", |
| "example_json": "config/examples/x2t_video_example.json", |
| "save_path_prefix": "results/x2t_video_sample", |
| }, |
| } |
|
|
| def init_from_model_path_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments): |
| |
| path_dir = model_args.model_path |
| ema_path = osp.join(path_dir, "ema.safetensors") |
| model_path = osp.join(path_dir, "model.safetensors") |
|
|
|
|
| model_path_ft = None |
| if osp.exists(model_path): |
| model_path_ft = model_path |
| elif osp.exists(ema_path): |
| model_path_ft = ema_path |
|
|
| if model_path_ft: |
| model_state_dict = load_file(model_path_ft, device="cpu") |
| else: |
| raise FileNotFoundError( |
| f"Fine-tuning failed: No valid checkpoint ('ema.safetensors' or 'model.safetensors') found in {path_dir}" |
| ) |
|
|
| |
| |
| if 'latent_pos_embed.pos_embed' in model_state_dict: |
| model_state_dict.pop('latent_pos_embed.pos_embed') |
|
|
| msg = model.load_state_dict(model_state_dict, strict=False) |
| clean_memory(model_state_dict) |
|
|
| return msg |
|
|
|
|
| def clean_memory(*objects): |
| """Clear temporary container references and release unused GPU allocator cache.""" |
| for obj in objects: |
| if isinstance(obj, dict): |
| obj.clear() |
| elif isinstance(obj, (list, set)): |
| obj.clear() |
| import gc |
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
|
|
| def apply_inference_defaults( |
| model_args: ModelArguments, |
| data_args: DataArguments, |
| inference_args: InferenceArguments, |
| ) -> None: |
| if inference_args.task not in TASK_DEFAULT_CONFIGS: |
| raise ValueError(f"Unsupported inference task: {inference_args.task}") |
|
|
| task_config = TASK_DEFAULT_CONFIGS[inference_args.task] |
| default_inference_args = InferenceArguments() |
|
|
| model_family = task_config.get("model_family", "") |
| if not model_args.model_path and model_family: |
| model_args.model_path = get_model_path(f"lance.{model_family}") |
| if not getattr(model_args, "llm_path", ""): |
| model_args.llm_path = model_args.model_path |
| if not model_args.vit_path: |
| model_args.vit_path = get_model_path("vit.qwen2_5_vl") |
|
|
| if not data_args.val_dataset_config_file and task_config.get("example_json"): |
| data_args.val_dataset_config_file = task_config["example_json"] |
|
|
| if inference_args.save_path_gen == default_inference_args.save_path_gen and task_config.get("save_path_prefix"): |
| inference_args.save_path_gen = task_config["save_path_prefix"] |
| if inference_args.validation_max_samples == default_inference_args.validation_max_samples: |
| inference_args.validation_max_samples = INTERNAL_VALIDATION_MAX_SAMPLES |
| if inference_args.video_height == default_inference_args.video_height: |
| inference_args.video_height = int(task_config.get("video_height", default_inference_args.video_height)) |
| if inference_args.video_width == default_inference_args.video_width: |
| inference_args.video_width = int(task_config.get("video_width", default_inference_args.video_width)) |
| if inference_args.resolution == default_inference_args.resolution: |
| inference_args.resolution = task_config.get("resolution", default_inference_args.resolution) |
| if inference_args.text_template == default_inference_args.text_template: |
| inference_args.text_template = bool(task_config.get("text_template", default_inference_args.text_template)) |
|
|
|
|
| def save_prompt_results(prompt_data_dict, save_path_gen, logger): |
| """Save validation results to a JSON file.""" |
| prompt_json_path = os.path.join(save_path_gen, PROMPT_JSON_FILENAME) |
| with open(prompt_json_path, 'w', encoding='utf-8') as f: |
| json.dump(prompt_data_dict, f, ensure_ascii=False, indent=2) |
|
|
|
|
| def normalize_understanding_answer(text: Optional[str]) -> str: |
| """Normalize generated understanding text before exporting it.""" |
| if text is None: |
| return "" |
| return text.replace("<|im_end|>", "").strip() |
|
|
|
|
| def save_understanding_results( |
| prompt_data_dict: dict, |
| dataset_config_file: str, |
| save_path_gen: str, |
| ) -> None: |
| """Save x2t results as a structured result.json file.""" |
| with open(dataset_config_file, "r", encoding="utf-8") as f: |
| dataset_samples = json.load(f) |
|
|
| result_entries = [] |
| for sample_key, sample in dataset_samples.items(): |
| interleave_array = sample.get("interleave_array", []) |
| element_dtype_array = sample.get("element_dtype_array", []) |
| if len(interleave_array) < 2 or not element_dtype_array: |
| continue |
|
|
| visual_path = interleave_array[0] |
| text_payload = interleave_array[1] |
| question = text_payload[1] if isinstance(text_payload, list) and len(text_payload) > 1 else "" |
| modality = element_dtype_array[0] |
|
|
| lookup_keys = [os.path.basename(visual_path), sample_key] |
| generated_answer = "" |
| for lookup_key in lookup_keys: |
| if lookup_key in prompt_data_dict: |
| generated_answer = prompt_data_dict[lookup_key] |
| break |
|
|
| result_entries.append( |
| { |
| modality: visual_path, |
| "question": question, |
| "answer": normalize_understanding_answer(generated_answer), |
| } |
| ) |
|
|
| result_json_path = os.path.join(save_path_gen, RESULT_JSON_FILENAME) |
| with open(result_json_path, "w", encoding="utf-8") as f: |
| json.dump(result_entries, f, ensure_ascii=False, indent=2) |
|
|
|
|
| def validate_on_fixed_batch( |
| fsdp_model: Lance, |
| vae_model: Optional[WanVideoVAE], |
| tokenizer: Qwen2Tokenizer, |
| val_data_cpu: dict, |
| training_args: InferenceArguments, |
| model_args: ModelArguments, |
| inference_args: InferenceArguments, |
| new_token_ids, |
| image_token_id: int, |
| device: int, |
| save_source_video: bool = False, |
| save_path_gen: str = "", |
| save_path_gt: str = "", |
| ): |
| val_data = val_data_cpu.cuda(device).to_dict() |
| fsdp_model = fsdp_model.to(device=device, dtype=torch.bfloat16) |
|
|
| with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): |
| |
| if "padded_videos" in val_data.keys(): |
| val_data["padded_latent"] = make_padded_latent(val_data["padded_videos"], val_data["vae_data_mode"], vae_model) |
|
|
| |
| if inference_args.task in GENERATION_TASKS: |
| params = { |
| "val_packed_text_ids": val_data["packed_text_ids"], |
| "val_packed_text_indexes": val_data["packed_text_indexes"], |
| "val_sample_lens": val_data["sample_lens"], |
| "val_packed_position_ids": val_data["packed_position_ids"], |
| "val_split_lens": val_data["split_lens"], |
| "val_attn_modes": val_data["attn_modes"], |
| "val_sample_N_target": val_data["sample_N_target"], |
| "val_packed_vae_token_indexes": val_data["packed_vae_token_indexes"], |
| "timestep_shift": training_args.validation_timestep_shift, |
| "num_timesteps": training_args.validation_num_timesteps, |
| "val_mse_loss_indexes": val_data.get("mse_loss_indexes", None), |
| "val_padded_latent": val_data["padded_latent"], |
| "video_sizes": val_data["video_sizes"], |
| "cfg_text_scale": model_args.cfg_text_scale, |
| "cfg_interval": training_args.cfg_interval, |
| "cfg_renorm_min": training_args.cfg_renorm_min, |
| "cfg_renorm_type": training_args.cfg_renorm_type, |
| "device": device, |
| "dtype": torch.bfloat16, |
| "new_token_ids": new_token_ids, |
| "max_samples": training_args.validation_max_samples, |
| "validation_noise_seed": training_args.validation_noise_seed, |
| "apply_chat_template": training_args.apply_chat_template, |
| "apply_qwen_2_5_vl_pos_emb": training_args.apply_qwen_2_5_vl_pos_emb, |
| "image_token_id": image_token_id, |
| "val_packed_vit_token_indexes": val_data.get("packed_vit_token_indexes", None), |
| "val_packed_vit_tokens": val_data.get("packed_vit_tokens", None), |
| "vit_video_grid_thw": val_data.get("vit_video_grid_thw", None), |
| "vae_video_grid_thw": val_data["vae_video_grid_thw"], |
| "video_grid_thw": val_data.get("video_grid_thw", None), |
| "caption": val_data.get("caption", None), |
| "sample_task": val_data["sample_task"], |
| "sample_modality": val_data["sample_modality"], |
| "cfg_type": training_args.cfg_type, |
| "cfg_uncond_token_id": training_args.cfg_uncond_token_id, |
| "index": val_data["index"], |
| "val_padded_videos": val_data["padded_videos"] if save_source_video else None, |
| } |
| if inference_args.use_KVcache: |
| denoise_latent, captions, padded_videos, index = fsdp_model.validation_gen_KVcache(**params) |
| else: |
| denoise_latent, captions, padded_videos, index = fsdp_model.validation_gen(**params) |
|
|
| |
| for i_val, latent in enumerate(denoise_latent): |
| if inference_args.task in {TASK_IMAGE_EDIT, TASK_VIDEO_EDIT}: |
| target_latents = [latent[-1]] |
| else: |
| target_latents = latent |
|
|
| v_list = [] |
| for latent_ in target_latents: |
| v_list.append(vae_model.vae_decode([latent_])[0]) |
|
|
| save_item_name = f"{index:06d}" if isinstance(index, int) else index |
| v_thwc = decode_video_tensor(v_list, save_path=save_path_gen, save_half=False, save_item_name=save_item_name) |
|
|
| if v_thwc.shape[0] > 1: |
| prompt_data_path = f"{save_item_name}.mp4" |
| else: |
| prompt_data_path = f"{save_item_name}.png" |
| inference_args.prompt_data_dict[prompt_data_path] = captions[i_val] |
|
|
| if save_source_video: |
| curr_padded_videos = padded_videos[i_val * 2 : (i_val + 1) * 2] |
| v_thwc_gt = decode_video_tensor(curr_padded_videos[-1:], save_path=save_path_gt, save_item_name=save_item_name) |
| del curr_padded_videos, v_thwc_gt |
|
|
| del v_list, v_thwc, latent, target_latents |
| clean_memory() |
|
|
| del denoise_latent, captions, padded_videos, params |
| clean_memory() |
|
|
| elif inference_args.task in UNDERSTANDING_TASKS: |
| generated_sequence_all, captions, index = fsdp_model.validation_video_to_text( |
| val_packed_text_ids=val_data["packed_text_ids"], |
| val_packed_text_indexes=val_data["packed_text_indexes"], |
| val_packed_position_ids=val_data["packed_position_ids"], |
| val_sample_N_target=val_data["sample_N_target"], |
| val_split_lens=val_data["split_lens"], |
| val_attn_modes=val_data["attn_modes"], |
| val_sample_lens=val_data["sample_lens"], |
| val_sample_type=val_data["sample_type"], |
| val_packed_vit_tokens=val_data["packed_vit_tokens"], |
| val_vit_video_grid_thw=val_data["vit_video_grid_thw"], |
| val_ce_loss_indexes=val_data["ce_loss_indexes"], |
| max_samples=training_args.validation_max_samples, |
| max_length=MAX_GENERATION_LENGTH, |
| device=device, |
| dtype=torch.bfloat16, |
| new_token_ids=new_token_ids, |
| pad_token_id=tokenizer.pad_token_id, |
| vocab_size=len(tokenizer), |
| caption=val_data.get("caption_cn", None), |
| tokenizer=tokenizer, |
| apply_chat_template=training_args.apply_chat_template, |
| apply_qwen_2_5_vl_pos_emb=training_args.apply_qwen_2_5_vl_pos_emb, |
| do_sample=False, |
| image_token_id=image_token_id, |
| index=val_data["index"], |
| ) |
|
|
| for i_val, generated_sequence in enumerate(generated_sequence_all): |
| cap = tokenizer.decode(generated_sequence[:, 0]) |
| |
| inference_args.prompt_data_dict[index] = f"{cap}" |
| del generated_sequence |
|
|
| del generated_sequence_all, captions |
| clean_memory() |
|
|
| del val_data |
| clean_memory() |
|
|
|
|
| def main(): |
| |
| assert torch.cuda.is_available() |
| if "RANK" in os.environ and "WORLD_SIZE" in os.environ: |
| dist.init_process_group("nccl") |
| GLOBAL_RANK = dist.get_rank() |
| WORLD_SIZE = dist.get_world_size() |
| else: |
| GLOBAL_RANK = 0 |
| WORLD_SIZE = 1 |
|
|
| LOCAL_RANK = GLOBAL_RANK % torch.cuda.device_count() |
| DEVICE = LOCAL_RANK |
| torch.cuda.set_device(DEVICE) |
|
|
| |
| parser = HfArgumentParser((ModelArguments, DataArguments, InferenceArguments)) |
| model_args, data_args, inference_args = cast( |
| Tuple[ModelArguments, DataArguments, InferenceArguments], |
| parser.parse_args_into_dataclasses(), |
| ) |
| training_args = inference_args |
|
|
| |
| apply_inference_defaults(model_args, data_args, inference_args) |
| training_args.validation_noise_seed = training_args.validation_data_seed |
|
|
| logger = get_logger() |
| log_rank0 = print if GLOBAL_RANK == 0 else (lambda *_: None) |
|
|
| def log_stage(stage_name: str, start_time: float, extra: str = ""): |
| elapsed = time.perf_counter() - start_time |
| suffix = f" | {extra}" if extra else "" |
| log_rank0(f"[startup] {stage_name} done in {elapsed:.2f}s{suffix}") |
|
|
| |
| seed = training_args.global_seed * WORLD_SIZE + GLOBAL_RANK |
| set_seed(seed) |
|
|
| |
| stage_start = time.perf_counter() |
| log_rank0(f"[startup] Loading LLM config: {osp.join(model_args.model_path, 'llm_config.json')}") |
| llm_config: Qwen2Config = Qwen2Config.from_json_file(osp.join(model_args.model_path, "llm_config.json")) |
| log_stage("LLM config load", stage_start) |
|
|
| llm_config.layer_module = model_args.layer_module |
| llm_config.qk_norm = model_args.llm_qk_norm |
| llm_config.qk_norm_und = model_args.llm_qk_norm_und |
| llm_config.qk_norm_gen = model_args.llm_qk_norm_gen |
|
|
| llm_config.tie_word_embeddings = model_args.tie_word_embeddings |
| llm_config.freeze_und = training_args.freeze_und |
| llm_config.apply_qwen_2_5_vl_pos_emb = training_args.apply_qwen_2_5_vl_pos_emb |
|
|
| stage_start = time.perf_counter() |
| log_rank0(f"[startup] Initializing LLM weights: {model_args.model_path}") |
| language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config) |
| log_stage("LLM weight init", stage_start) |
|
|
| if training_args.visual_und: |
| if model_args.vit_type in ("qwen2_5_vl", "qwen_2_5_vl_original"): |
| stage_start = time.perf_counter() |
| log_rank0(f"[startup] Loading VIT config: {model_args.vit_path}") |
| vit_config = Qwen2_5_VLVisionConfig.from_pretrained(model_args.vit_path) |
| log_stage("VIT config load", stage_start) |
|
|
| stage_start = time.perf_counter() |
| log_rank0(f"[startup] Loading VIT weights: {osp.join(model_args.vit_path, 'vit.safetensors')}") |
| vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config) |
| vit_weights = load_file(osp.join(model_args.vit_path, "vit.safetensors")) |
| vit_model.load_state_dict(vit_weights, strict=True) |
| log_stage("VIT weight load", stage_start) |
| else: |
| raise ValueError(f"Unsupported vit_type: {model_args.vit_type}") |
|
|
| clean_memory(vit_weights) |
|
|
| if training_args.visual_gen: |
| stage_start = time.perf_counter() |
| log_rank0("[startup] Initializing VAE") |
| vae_model = WanVideoVAE() |
| vae_config: AutoEncoderParams = deepcopy(vae_model.vae_config) |
| log_stage("VAE init", stage_start) |
| else: |
| vae_model = None |
| vae_config = None |
|
|
| |
| config = LanceConfig( |
| visual_gen=training_args.visual_gen, |
| visual_und=training_args.visual_und, |
| llm_config=llm_config, |
| vit_config=vit_config if training_args.visual_und else None, |
| vae_config=vae_config if training_args.visual_gen else None, |
| latent_patch_size=model_args.latent_patch_size, |
| max_num_frames=model_args.max_num_frames, |
| max_latent_size=model_args.max_latent_size, |
| vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side, |
| connector_act=model_args.connector_act, |
| interpolate_pos=model_args.interpolate_pos, |
| timestep_shift=training_args.timestep_shift, |
| ) |
| model: Lance = Lance( |
| language_model=language_model, |
| vit_model=vit_model if training_args.visual_und else None, |
| vit_type=model_args.vit_type, |
| config=config, |
| training_args=training_args, |
| ) |
| stage_start = time.perf_counter() |
| log_rank0(f"[startup] Moving Lance model to GPU {DEVICE}") |
| model = model.to(DEVICE) |
| log_stage("Lance model move to GPU", stage_start) |
|
|
| |
| stage_start = time.perf_counter() |
| log_rank0(f"[startup] Loading tokenizer: {model_args.model_path}") |
| tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path) |
|
|
| tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer) |
| log_stage("tokenizer load and special token init", stage_start, extra=f"num_new_tokens={num_new_tokens}") |
|
|
| |
| if training_args.copy_init_moe: |
| language_model.init_moe() |
|
|
| init_from_model_path_if_needed(model, model_args) |
|
|
| |
| if num_new_tokens > 0: |
| model.language_model.resize_token_embeddings(len(tokenizer)) |
| model.config.llm_config.vocab_size = len(tokenizer) |
| model.language_model.config.vocab_size = len(tokenizer) |
|
|
| if model_args.vit_type.lower() == "qwen2_5_vl": |
| from common.model.hacks import hack_qwen2_5_vl_config |
| language_model = hack_qwen2_5_vl_config(language_model) |
|
|
| image_token_id = language_model.config.video_token_id |
| new_token_ids.update({"image_token_id": image_token_id}) |
| model.update_tokenizer(tokenizer=tokenizer) |
|
|
| if model_args.tie_word_embeddings: |
| |
| model.language_model.untie_lm_head() |
| model.language_model.copy_new_token_rows_to_lm_head(num_new_tokens) |
|
|
| |
| model_args.tie_word_embeddings = False |
| llm_config.tie_word_embeddings = False |
| else: |
| assert model.language_model.get_input_embeddings().weight.data.data_ptr() != model.language_model.get_output_embeddings().weight.data.data_ptr(), 'tie_word_embeddings conflict' |
|
|
| model = model.to(device=DEVICE, dtype=torch.bfloat16) |
| model.eval() |
| if vae_model is not None and hasattr(vae_model, "eval"): |
| vae_model.eval() |
|
|
| |
| stage_start = time.perf_counter() |
| log_rank0(f"[startup] Loading dataset config and validation set: {data_args.val_dataset_config_file}") |
| dataset_config = DataConfig.from_yaml(data_args.val_dataset_config_file) |
|
|
| |
| if training_args.visual_und: |
| dataset_config.vit_patch_size = model_args.vit_patch_size |
| dataset_config.vit_patch_size_temporal = model_args.vit_patch_size_temporal |
| dataset_config.vit_max_num_patch_per_side = model_args.vit_max_num_patch_per_side |
| |
| if training_args.visual_gen: |
| assert len(model_args.latent_patch_size) == 3, "len(latent_patch_size) must be 3" |
| vae_downsample = tuple_mul( |
| model_args.latent_patch_size, (vae_config.downsample_temporal, vae_config.downsample_spatial, vae_config.downsample_spatial) |
| ) |
| dataset_config.latent_patch_size = model_args.latent_patch_size |
| dataset_config.vae_downsample = vae_downsample |
| dataset_config.max_latent_size = model_args.max_latent_size |
| dataset_config.max_num_frames = model_args.max_num_frames |
|
|
| |
| dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob |
| dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob |
| dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob |
|
|
| |
| dataset_config.num_frames = inference_args.num_frames |
| dataset_config.H = inference_args.video_height |
| dataset_config.W = inference_args.video_width |
| dataset_config.task = inference_args.task |
| dataset_config.resolution = inference_args.resolution |
| dataset_config.text_template = inference_args.text_template |
| val_dataset = ValidationDataset( |
| jsonl_path= data_args.val_dataset_config_file, |
| tokenizer=tokenizer, |
| data_args=data_args, |
| model_args=model_args, |
| training_args=training_args, |
| new_token_ids=new_token_ids, |
| dataset_config=dataset_config, |
| local_rank=GLOBAL_RANK, |
| world_size=WORLD_SIZE, |
| ) |
|
|
| val_loader = DataLoader( |
| val_dataset, |
| batch_size=1, |
| num_workers=0, |
| pin_memory=True, |
| collate_fn=simple_custom_collate, |
| drop_last=True, |
| prefetch_factor=None, |
| persistent_workers=False, |
| multiprocessing_context=None, |
| ) |
| log_stage("validation set and DataLoader init", stage_start, extra=f"dataset_size={len(val_dataset)}") |
|
|
| |
| val_loader_iter = iter(val_loader) |
|
|
| |
| if not hasattr(inference_args, "prompt_data_dict"): |
| inference_args.prompt_data_dict = {} |
|
|
| if not os.path.exists(inference_args.save_path_gen): |
| os.makedirs(inference_args.save_path_gen) |
|
|
| for epoch in trange(len(val_loader), desc="Validating", unit="batch", leave=True, ncols=80, disable=(GLOBAL_RANK != 0)): |
| try: |
| val_data_cpu = next(val_loader_iter) |
| except StopIteration: |
| break |
|
|
| validate_on_fixed_batch( |
| fsdp_model=model, |
| vae_model=vae_model, |
| tokenizer=tokenizer, |
| val_data_cpu=val_data_cpu, |
| training_args=training_args, |
| model_args=model_args, |
| inference_args=inference_args, |
| new_token_ids=new_token_ids, |
| image_token_id=image_token_id, |
| device=DEVICE, |
| save_source_video=False, |
| save_path_gen=inference_args.save_path_gen, |
| save_path_gt="", |
| ) |
| del val_data_cpu |
| clean_memory() |
|
|
| |
| if dist.is_initialized(): |
| dist.barrier() |
| gathered = [None for _ in range(dist.get_world_size())] |
| dist.all_gather_object(gathered, inference_args.prompt_data_dict) |
|
|
| if GLOBAL_RANK == 0: |
| merged = {} |
| for d in gathered: |
| merged.update(d) |
| inference_args.prompt_data_dict = merged |
| save_prompt_results(inference_args.prompt_data_dict, inference_args.save_path_gen, logger) |
| if inference_args.task in UNDERSTANDING_TASKS: |
| save_understanding_results( |
| prompt_data_dict=inference_args.prompt_data_dict, |
| dataset_config_file=data_args.val_dataset_config_file, |
| save_path_gen=inference_args.save_path_gen, |
| ) |
|
|
| elif GLOBAL_RANK == 0: |
| save_prompt_results(inference_args.prompt_data_dict, inference_args.save_path_gen, logger) |
| if inference_args.task in UNDERSTANDING_TASKS: |
| save_understanding_results( |
| prompt_data_dict=inference_args.prompt_data_dict, |
| dataset_config_file=data_args.val_dataset_config_file, |
| save_path_gen=inference_args.save_path_gen, |
| ) |
|
|
| if dist.is_initialized(): |
| dist.destroy_process_group() |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|