import os from typing import Dict, Tuple, Optional import torch import torch.distributed as dist from torch import nn, Tensor from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig from peft import LoraConfig, get_peft_model, PeftModel from src.model.processor import QWEN2_5_VL_TOKENSELECTION from src.arguments import ModelArguments, TrainingArguments from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, \ backbone2model, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V from src.arguments import ModelArguments from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, INTERNVIDEO2, \ QWEN2_VL_TOKENSELECTION, backbone2model, GME, VLM_IMAGE_TOKENS, LamRA, LamRA_QWEN2_5, COLPALI from src.model.baseline_backbone.colpali import ColPali from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration from transformers import modeling_utils if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None: modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", 'rowwise'] try: from safetensors.torch import load_file as safe_load print('safetensors.torch') except Exception: safe_load = None # 在文件内(例如 MMEBModel 前)新增一个 Token-wise MLP class TokenWiseMLP(nn.Module): def __init__(self, hidden_size: int, mlp_hidden_size: Optional[int] = None, dropout: float = 0.1): super().__init__() h = mlp_hidden_size or hidden_size self.proj = nn.Sequential( nn.Linear(hidden_size, h), nn.GELU(), nn.Dropout(dropout), nn.Linear(h, hidden_size), ) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, T, H] B, T, H = x.shape y = self.proj(x.reshape(B * T, H)).reshape(B, T, H) return y class MMEBModel(nn.Module): TRANSFORMER_CLS = AutoModelForCausalLM def __init__(self, encoder: PreTrainedModel, pooling: str = 'last', normalize: bool = False, temperature: float = 0.02, ): super().__init__() self.config = encoder.config self.encoder = encoder self.pooling = pooling self.normalize = normalize self.temperature = temperature self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') self.is_ddp = dist.is_initialized() and dist.get_world_size() > 1 self.process_rank = dist.get_rank() if self.is_ddp else 0 self.world_size = dist.get_world_size() if self.is_ddp else 1 # 早退出相关:默认关闭,构建/加载时按 ModelArguments 决定是否开启 self.enable_early_mlp = False self.early_layer_index = 20 # 第20层 self.early_loss_weight = 0.0 # 第二阶段如需联合loss,这里可>0;第一阶段只训early会在trainer里走compute_early_only def add_early_mlp(self, layer_index: int = 20, mlp_hidden_size: Optional[int] = None, dropout: float = 0.1): # 将 early_mlp 挂到 encoder 下,保证保存/加载时与encoder一起处理 hidden_size = getattr(self.config, "hidden_size", None) if hidden_size is None: raise ValueError("config.hidden_size 未找到,无法初始化 early MLP") self.enable_early_mlp = True self.early_layer_index = int(layer_index) # 将模块挂在 encoder 下,state_dict key 会是 encoder.early_mlp.* self.encoder.early_mlp = TokenWiseMLP(hidden_size, mlp_hidden_size, dropout) # 将配置写入 encoder.config,保存时会带进 config.json setattr(self.encoder.config, "enable_early_mlp", True) setattr(self.encoder.config, "early_layer_index", self.early_layer_index) setattr(self.encoder.config, "early_mlp_hidden_size", mlp_hidden_size or hidden_size) setattr(self.encoder.config, "early_mlp_dropout", dropout) def _apply_early(self, hidden: torch.Tensor) -> torch.Tensor: # 对第 early_layer_index 层的 token hidden 过 MLP 用于早退出 if not self.enable_early_mlp or not hasattr(self.encoder, "early_mlp"): raise RuntimeError("early_mlp 未启用,无法计算早退出表示") return self.encoder.early_mlp(hidden) def encode_input(self, input, return_early: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ 返回: - 当 return_early=False: 只返回 final_pooled - 当 return_early=True: 返回 (final_pooled, early_pooled) """ if getattr(self, "model_backbone", None) == INTERNVIDEO2: if "input_ids" in input.keys(): # text side text_output = self.encoder.get_text_encoder()( input["input_ids"], attention_mask=input["attention_mask"], return_dict=True, mode="text", ) text_embeds = text_output.last_hidden_state pooled_text_embeds = text_embeds[:, 0] pooled_output = self.encoder.text_proj(pooled_text_embeds) pooled_output /= pooled_output.norm(dim=-1, keepdim=True) return pooled_output else: _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True) vfeat = self.encoder.vision_proj(vfeat) vfeat /= vfeat.norm(dim=-1, keepdim=True) return vfeat elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]: # pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] # we are actually passing video queries so this should not happen images = [] for imgs in input['images']: # if multi images are given, select the middle frame only if isinstance(imgs, list): imgs = imgs[len(imgs) // 2] assert not isinstance(imgs, list) # make sure we have extracted the middle frame and it is no longer a list images.append(imgs) else: images.append(imgs) pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images) return pooled_output elif getattr(self, "model_backbone", None) == COLPALI: pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) return pooled_output elif getattr(self, "model_backbone", None) == LLAVA_NEXT: input['pixel_values'] = input['pixel_values'].squeeze(dim=1) input['image_sizes'] = input['image_sizes'].squeeze(dim=1) hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True) hidden_states = hidden_states.hidden_states[-1] pooled_output = self._pooling(hidden_states, input['attention_mask']) return pooled_output else: # 泛化:大多数HF模型这里都能拿到 hidden_states outputs = self.encoder(**input, return_dict=True, output_hidden_states=True) hidden_states = outputs.hidden_states final_hidden = hidden_states[-1] final_pooled = self._pooling(final_hidden, input['attention_mask']) if return_early and self.enable_early_mlp: idx = min(max(0, self.early_layer_index), len(hidden_states) - 1) early_hidden = hidden_states[idx] early_hidden = self._apply_early(early_hidden) early_pooled = self._pooling(early_hidden, input['attention_mask']) return final_pooled, early_pooled return final_pooled def _pooling(self, last_hidden_state, attention_mask): if self.pooling == 'last' or self.pooling == 'eos': left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) if attention_mask is not None else False batch_size = last_hidden_state.shape[0] if attention_mask is None: reps = last_hidden_state[:, -1, :] elif left_padding: reps = last_hidden_state[torch.arange(batch_size), -1, :] else: eos_indices = attention_mask.sum(dim=1) - 1 reps = last_hidden_state[ torch.arange(batch_size, device=last_hidden_state.device), eos_indices] else: raise NotImplementedError if self.normalize: reps = torch.nn.functional.normalize(reps, p=2, dim=-1) return reps @classmethod def build(cls, model_args: ModelArguments, **kwargs): config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) model_backbone = get_backbone_name(hf_config=config) print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}') # Loading the base model if model_backbone == PHI3V: config._attn_implementation = "eager" config.padding_side = "right" config.use_cache = False base_model = Phi3VForCausalLM.from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, ) elif model_backbone == LLAVA_NEXT: config.use_cache = False config.padding_side = "left" base_model = LlavaNextForConditionalGeneration.from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, ) elif model_backbone in [QWEN2_VL, QWEN2_5_VL]: config._attn_implementation = "flash_attention_2" config.padding_side = "left" config.use_cache = False base_model = backbone2model[model_backbone].from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, ) elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]: config._attn_implementation = "flash_attention_2" config.padding_side = "left" config.use_cache = False from .utils import parse_layer_type lm_qwen_layer = 28 vis_qwen_layer = 32 lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer) vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer) base_model = backbone2model[model_backbone].from_pretrained( model_args.model_name, config=config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, lm_skip_layer=lm_skip_layer, vis_skip_layer=vis_skip_layer, ) else: config.use_cache = False base_model = cls.TRANSFORMER_CLS.from_pretrained( model_args.model_name, **kwargs, config=config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, trust_remote_code=True) # if model_args.lora: # print_master(f'Loading lora adapter from {base_model}') # lora_config = LoraConfig( # r=model_args.lora_r, # lora_alpha=model_args.lora_alpha, # target_modules=model_args.lora_target_modules.split(','), # lora_dropout=model_args.lora_dropout, # init_lora_weights="gaussian", # use_dora=True, # inference_mode=False # ) # lora_model = get_peft_model(base_model, lora_config) # model = cls( # encoder=lora_model, # pooling=model_args.pooling, # normalize=model_args.normalize, # temperature=model_args.temperature # ) # else: # model = cls( # encoder=base_model, # pooling=model_args.pooling, # normalize=model_args.normalize, # temperature=model_args.temperature # ) if model_args.lora: # 修改开始:优先从 model_args.model_name 加载已有LoRA;否则才新建 def _has_adapter_files(path: str): if not (path and os.path.isdir(path)): return False for fname in ("adapter_model.safetensors", "adapter_model.bin", "adapter_config.json"): if os.path.exists(os.path.join(path, fname)): return True return False if _has_adapter_files(model_args.model_name): print_master(f"[build] detected LoRA adapter in '{model_args.model_name}', loading pretrained adapter.") lora_config = LoraConfig.from_pretrained(model_args.model_name) lora_model = PeftModel.from_pretrained( base_model, model_args.model_name, config=lora_config, is_trainable=True ) # 一般不需要 merge,这里保持Adapter在线,后续freeze/训练由外层控制 else: print_master(f"[build] no adapter files in '{model_args.model_name}', create a new LoRA adapter.") lora_config = LoraConfig( r=model_args.lora_r, lora_alpha=model_args.lora_alpha, target_modules=model_args.lora_target_modules.split(','), lora_dropout=model_args.lora_dropout, init_lora_weights="gaussian", use_dora=True, inference_mode=False ) lora_model = get_peft_model(base_model, lora_config) model = cls( encoder=lora_model, pooling=model_args.pooling, normalize=model_args.normalize, temperature=model_args.temperature ) # 修改结束 else: model = cls( encoder=base_model, pooling=model_args.pooling, normalize=model_args.normalize, temperature=model_args.temperature ) # 早退出 MLP(可选) if getattr(model_args, "enable_early_mlp", False): model.add_early_mlp( layer_index=getattr(model_args, "early_layer_index", 20), mlp_hidden_size=getattr(model_args, "early_mlp_hidden_size", None), dropout=getattr(model_args, "early_mlp_dropout", 0.1), ) # 第二阶段如需联合loss,可以通过 model_args.early_loss_weight 传入 model.early_loss_weight = float(getattr(model_args, "early_loss_weight", 0.0)) model.model_backbone = model_backbone return model @classmethod def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs): # Loading the base model model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) if not hasattr(model_args, "model_backbone") or not model_args.model_backbone: model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type) setattr(model_args, 'model_backbone', model_backbone) print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}') if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}: config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) config._attn_implementation = "flash_attention_2" config.vision_config._attn_implementation = "flash_attention_2" base_model = backbone2model[model_args.model_backbone].from_pretrained( model_args.model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, config=config ) elif model_args.model_backbone == PHI3V: config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) config.use_cache = False config.padding_side = "right" base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True) base_model.padding_side = "right" elif model_args.model_backbone == INTERNVIDEO2: print_master(f'Loading backbone [{model_args.model_backbone}] from {"src/model/vlm_backbone/internvideo2/"}') config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/", trust_remote_code=True) base_model = backbone2model[model_args.model_backbone].from_pretrained("src/model/vlm_backbone/internvideo2/", config=config, trust_remote_code=True) elif model_args.model_backbone == GME: base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor']) setattr(base_model, 'config', config) elif model_args.model_backbone == LamRA: base_model = LamRAQwen2VL(model_args.model_name) setattr(base_model, 'config', config) elif model_args.model_backbone == LamRA_QWEN2_5: base_model = LamRAQwen25VL(model_args.model_name) setattr(base_model, 'config', config) elif model_args.model_backbone == COLPALI: base_model = ColPali.from_pretrained(model_args.model_name) setattr(base_model, 'config', config) else: # Loading external base model from HF config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) config.use_cache = False base_model = cls.TRANSFORMER_CLS.from_pretrained( model_name_or_path, **kwargs, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True) # Building the model on top of the base if model_args.lora: print_master(f'Loading LoRA from {model_name_or_path}') # lora_config = LoraConfig.from_pretrained(model_name_or_path) # lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable) # lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable) # if not is_trainable: # lora_model = lora_model.merge_and_unload() # model = cls( # encoder=lora_model, # pooling=model_args.pooling, # normalize=model_args.normalize, # temperature=model_args.temperature # ) # 修改开始:检测 adapter 是否存在,不存在则回落到 model_args.model_name def _has_adapter_files(path: str): if not (path and os.path.isdir(path)): return False for fname in ("adapter_model.safetensors", "adapter_model.bin", "adapter_config.json"): if os.path.exists(os.path.join(path, fname)): return True return False adapter_source = model_name_or_path if _has_adapter_files(model_name_or_path) else model_args.model_name if adapter_source != model_name_or_path: print_master(f"[load] adapter files not found in '{model_name_or_path}', fallback to '{adapter_source}'") # 某些 peft 版本要求 config 单独加载 lora_config = LoraConfig.from_pretrained(adapter_source) lora_model = PeftModel.from_pretrained(base_model, adapter_source, config=lora_config, is_trainable=is_trainable) lora_model.load_adapter(adapter_source, lora_model.active_adapter, is_trainable=is_trainable) if not is_trainable: lora_model = lora_model.merge_and_unload() model = cls( encoder=lora_model, pooling=model_args.pooling, normalize=model_args.normalize, temperature=model_args.temperature ) # 修改结束 else: model = cls( encoder=base_model, pooling=model_args.pooling, normalize=model_args.normalize, temperature=model_args.temperature ) model.model_backbone = model_args.model_backbone # 早退出 MLP:按配置/检查点信息启用,并尝试从checkpoint加载early_mlp权重 enable_early = bool(getattr(model_args, "enable_early_mlp", False) or getattr(model.encoder.config, "enable_early_mlp", False)) if enable_early: layer_index = int(getattr(model_args, "early_layer_index", getattr(model.encoder.config, "early_layer_index", 20))) mlp_hidden_size = getattr(model_args, "early_mlp_hidden_size", getattr(model.encoder.config, "early_mlp_hidden_size", None)) dropout = getattr(model_args, "early_mlp_dropout", getattr(model.encoder.config, "early_mlp_dropout", 0.1)) model.add_early_mlp(layer_index=layer_index, mlp_hidden_size=mlp_hidden_size, dropout=dropout) model.early_loss_weight = float(getattr(model_args, "early_loss_weight", 0.0)) # 从checkpoint加载 early_mlp.* 权重 ckpt_dir = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name if os.path.isdir(ckpt_dir): pt_path = os.path.join(ckpt_dir, "pytorch_model.bin") st_path = os.path.join(ckpt_dir, "model.safetensors") state = None if os.path.exists(pt_path): state = torch.load(pt_path, map_location="cpu") elif os.path.exists(st_path) and safe_load is not None: state = safe_load(st_path) if state is not None: early_state = { k.replace("early_mlp.", "", 1): v for k, v in state.items() if k.startswith("early_mlp.") } if early_state: missing, unexpected = model.encoder.early_mlp.load_state_dict(early_state, strict=False) print_master( f"Loaded early_mlp weights from {ckpt_dir} (missing={missing}, unexpected={unexpected})" ) # 新增兜底:尝试从 early_mlp.bin 回载 bin_path = os.path.join(ckpt_dir, "early_mlp.bin") if os.path.exists(bin_path): try: early_bin = torch.load(bin_path, map_location="cpu") missing, unexpected = model.encoder.early_mlp.load_state_dict(early_bin, strict=False) print_master( f"Loaded early_mlp weights from early_mlp.bin (missing={missing}, unexpected={unexpected})" ) except Exception as e: print_master(f"Failed to load early_mlp from early_mlp.bin: {e}") return model def save(self, output_dir: str): self.encoder.save_pretrained(output_dir) def forward(self, qry: Dict[str, torch.Tensor] = None, tgt: Dict[str, torch.Tensor] = None, compute_early_only: bool = False, *args, **kwargs): # 修改开始:支持单侧前向(GradCache 取表征时会这样调用) only_q = (qry is not None) and (tgt is None) only_t = (tgt is not None) and (qry is None) if only_q or only_t: # GradCache 的 get_rep_fn 期望的是“最终层”的检索表征;这里不要用 early,只返回 final single = qry if only_q else tgt rep = self.encode_input(single, return_early=False) # [B, D] return {"qry_reps": rep if only_q else None, "tgt_reps": rep if only_t else None} # 如果两边都没传,说明 split_and_process 有问题,给出更友好的报错 if qry is None and tgt is None: raise ValueError("MMEBModel.forward expected 'qry' and/or 'tgt' but got none. " "Check split_and_process_vlm_inputs / training_step packaging.") # 修改结束 # 下面是双侧(正常计算loss)的路径 if compute_early_only: if not self.enable_early_mlp: raise RuntimeError("compute_early_only=True 但 early_mlp 未启用") qry_final, qry_early = self.encode_input(qry, return_early=True) tgt_final, tgt_early = self.encode_input(tgt, return_early=True) assert qry_early is not None and tgt_early is not None qry_reps, tgt_reps = qry_early, tgt_early else: if self.enable_early_mlp and (self.training and self.early_loss_weight > 0): qry_final, qry_early = self.encode_input(qry, return_early=True) tgt_final, tgt_early = self.encode_input(tgt, return_early=True) else: qry_final = self.encode_input(qry, return_early=False) tgt_final = self.encode_input(tgt, return_early=False) qry_early, tgt_early = None, None qry_reps, tgt_reps = qry_final, tgt_final # DDP gather if self.is_ddp: all_qry_reps = self._dist_gather_tensor(qry_reps) all_tgt_reps = self._dist_gather_tensor(tgt_reps) else: all_qry_reps = qry_reps all_tgt_reps = tgt_reps # 主loss scores = self.compute_similarity(all_qry_reps, all_tgt_reps).view(all_qry_reps.size(0), -1) target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long) target = target * (all_qry_reps.size(0) // all_tgt_reps.size(0)) loss = self.cross_entropy(scores / self.temperature, target) if self.is_ddp: loss = loss * self.world_size # 辅助 early loss(可选) if (not compute_early_only) and self.training and self.enable_early_mlp and self.early_loss_weight > 0: assert qry_early is not None and tgt_early is not None if self.is_ddp: all_qry_e = self._dist_gather_tensor(qry_early) all_tgt_e = self._dist_gather_tensor(tgt_early) else: all_qry_e, all_tgt_e = qry_early, tgt_early scores_e = self.compute_similarity(all_qry_e, all_tgt_e).view(all_qry_e.size(0), -1) target_e = torch.arange(scores_e.size(0), device=scores_e.device, dtype=torch.long) target_e = target_e * (all_qry_e.size(0) // all_tgt_e.size(0)) loss_e = self.cross_entropy(scores_e / self.temperature, target_e) if self.is_ddp: loss_e = loss_e * self.world_size loss = loss + self.early_loss_weight * loss_e return loss def _dist_gather_tensor(self, t: Tensor): t = t.contiguous() all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] dist.all_gather(all_tensors, t) all_tensors[self.process_rank] = t all_tensors = torch.cat(all_tensors, dim=0) return all_tensors def compute_similarity(self, q_reps, p_reps): return torch.matmul(q_reps, p_reps.transpose(0, 1))