| import os |
| from typing import Dict, Tuple, Optional |
| import torch |
| import torch.distributed as dist |
| from torch import nn, Tensor |
| from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig |
| from peft import LoraConfig, get_peft_model, PeftModel |
from src.arguments import ModelArguments, TrainingArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION,
    PHI3V, E5_V, INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI,
    VLM_IMAGE_TOKENS, backbone2model, get_backbone_name, print_master,
)
| from src.model.baseline_backbone.colpali import ColPali |
| from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL |
| from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL |
| from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL |
| from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM |
| from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration |
|
|
| from transformers import modeling_utils |
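# Guard against transformers versions where modeling_utils.ALL_PARALLEL_STYLES is missing or None,
# which some from_pretrained code paths expect to be a list of parallel styles.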
| if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None: |
| modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", 'rowwise'] |
try:
    from safetensors.torch import load_file as safe_load
except Exception:
    safe_load = None
|
|
| |
| class TokenWiseMLP(nn.Module): |
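    """Token-wise two-layer MLP (Linear -> GELU -> Dropout -> Linear), applied independently to every token position."""
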
| def __init__(self, hidden_size: int, mlp_hidden_size: Optional[int] = None, dropout: float = 0.1): |
| super().__init__() |
| h = mlp_hidden_size or hidden_size |
| self.proj = nn.Sequential( |
| nn.Linear(hidden_size, h), |
| nn.GELU(), |
| nn.Dropout(dropout), |
| nn.Linear(h, hidden_size), |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| |
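        # x: (batch, seq_len, hidden); flatten the token dimension, apply the MLP, and restore the shape.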
| B, T, H = x.shape |
| y = self.proj(x.reshape(B * T, H)).reshape(B, T, H) |
| return y |
|
|
| class MMEBModel(nn.Module): |
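    """Embedding model wrapping a (V)LM backbone with pooling, a temperature-scaled InfoNCE objective,
    optional LoRA adapters, and an optional early-exit MLP head.

    Typical usage (sketch; field names follow ModelArguments in this repo):
        model = MMEBModel.build(model_args)
        loss = model(qry=qry_inputs, tgt=tgt_inputs)
    """
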
| TRANSFORMER_CLS = AutoModelForCausalLM |
|
|
| def __init__(self, |
| encoder: PreTrainedModel, |
| pooling: str = 'last', |
| normalize: bool = False, |
| temperature: float = 0.02, |
| ): |
| super().__init__() |
| self.config = encoder.config |
| self.encoder = encoder |
| self.pooling = pooling |
| self.normalize = normalize |
| self.temperature = temperature |
| self.cross_entropy = nn.CrossEntropyLoss(reduction='mean') |
| self.is_ddp = dist.is_initialized() and dist.get_world_size() > 1 |
| self.process_rank = dist.get_rank() if self.is_ddp else 0 |
| self.world_size = dist.get_world_size() if self.is_ddp else 1 |
|
|
| |
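        # Early-exit head is disabled by default; it is configured via add_early_mlp().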
| self.enable_early_mlp = False |
| self.early_layer_index = 20 |
| self.early_loss_weight = 0.0 |
|
|
| def add_early_mlp(self, layer_index: int = 20, |
| mlp_hidden_size: Optional[int] = None, |
| dropout: float = 0.1): |
        """Attach a token-wise MLP head that yields an extra 'early' representation from an intermediate layer."""
        hidden_size = getattr(self.config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("config.hidden_size not found; cannot initialize the early MLP")
        self.enable_early_mlp = True
        self.early_layer_index = int(layer_index)
        # Register the head as a submodule of the encoder so it follows the encoder's device and dtype.
        self.encoder.early_mlp = TokenWiseMLP(hidden_size, mlp_hidden_size, dropout)
        # Persist the settings on the encoder config so load() can rebuild the head later.
        setattr(self.encoder.config, "enable_early_mlp", True)
        setattr(self.encoder.config, "early_layer_index", self.early_layer_index)
        setattr(self.encoder.config, "early_mlp_hidden_size", mlp_hidden_size or hidden_size)
        setattr(self.encoder.config, "early_mlp_dropout", dropout)
|
|
| def _apply_early(self, hidden: torch.Tensor) -> torch.Tensor: |
| |
| if not self.enable_early_mlp or not hasattr(self.encoder, "early_mlp"): |
            raise RuntimeError("early_mlp is not enabled; cannot compute the early-exit representation")
| return self.encoder.early_mlp(hidden) |
|
|
| def encode_input(self, input, return_early: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
        """
        Returns:
            - return_early=False: only the pooled final-layer representation.
            - return_early=True: a tuple (final_pooled, early_pooled).
        """
| if getattr(self, "model_backbone", None) == INTERNVIDEO2: |
| if "input_ids" in input.keys(): |
| |
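                # Text query: encode with the text tower, take the first token, project, and L2-normalize.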
| text_output = self.encoder.get_text_encoder()( |
| input["input_ids"], |
| attention_mask=input["attention_mask"], |
| return_dict=True, |
| mode="text", |
| ) |
| text_embeds = text_output.last_hidden_state |
| pooled_text_embeds = text_embeds[:, 0] |
| pooled_output = self.encoder.text_proj(pooled_text_embeds) |
| pooled_output /= pooled_output.norm(dim=-1, keepdim=True) |
| return pooled_output |
| else: |
| _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True) |
| vfeat = self.encoder.vision_proj(vfeat) |
| vfeat /= vfeat.norm(dim=-1, keepdim=True) |
| return vfeat |
| elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]: |
| |
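            # GME / LamRA baselines consume raw text and images: strip the VLM image placeholder from the text
            # and, for multi-frame inputs, keep only the middle frame.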
| texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] |
| images = [] |
| for imgs in input['images']: |
| |
| if isinstance(imgs, list): |
| imgs = imgs[len(imgs) // 2] |
| assert not isinstance(imgs, list) |
| images.append(imgs) |
| else: |
| images.append(imgs) |
| pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images) |
| return pooled_output |
| elif getattr(self, "model_backbone", None) == COLPALI: |
| pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) |
| return pooled_output |
| elif getattr(self, "model_backbone", None) == LLAVA_NEXT: |
| input['pixel_values'] = input['pixel_values'].squeeze(dim=1) |
| input['image_sizes'] = input['image_sizes'].squeeze(dim=1) |
| hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True) |
| hidden_states = hidden_states.hidden_states[-1] |
| pooled_output = self._pooling(hidden_states, input['attention_mask']) |
| return pooled_output |
| else: |
| |
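            # Default VLM path: a single forward pass with output_hidden_states=True so both the final layer
            # and (optionally) an intermediate layer are available for pooling.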
| outputs = self.encoder(**input, return_dict=True, output_hidden_states=True) |
| hidden_states = outputs.hidden_states |
| final_hidden = hidden_states[-1] |
| final_pooled = self._pooling(final_hidden, input['attention_mask']) |
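            # Early-exit branch: take an intermediate layer (index clamped to a valid range) and refine it
            # with the token-wise early_mlp before pooling.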
| if return_early and self.enable_early_mlp: |
| idx = min(max(0, self.early_layer_index), len(hidden_states) - 1) |
| early_hidden = hidden_states[idx] |
| early_hidden = self._apply_early(early_hidden) |
| early_pooled = self._pooling(early_hidden, input['attention_mask']) |
| return final_pooled, early_pooled |
| return final_pooled |
|
|
| def _pooling(self, last_hidden_state, attention_mask): |
| if self.pooling == 'last' or self.pooling == 'eos': |
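            # With left padding every sequence ends at the last position; otherwise locate each EOS via the mask.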
| left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) if attention_mask is not None else False |
| batch_size = last_hidden_state.shape[0] |
| if attention_mask is None: |
| reps = last_hidden_state[:, -1, :] |
| elif left_padding: |
| reps = last_hidden_state[torch.arange(batch_size), -1, :] |
| else: |
| eos_indices = attention_mask.sum(dim=1) - 1 |
| reps = last_hidden_state[ |
| torch.arange(batch_size, device=last_hidden_state.device), eos_indices] |
| else: |
| raise NotImplementedError |
| if self.normalize: |
| reps = torch.nn.functional.normalize(reps, p=2, dim=-1) |
| return reps |
|
|
| @classmethod |
| def build(cls, model_args: ModelArguments, **kwargs): |
| config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) |
| model_backbone = get_backbone_name(hf_config=config) |
| print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}') |
| |
| if model_backbone == PHI3V: |
| config._attn_implementation = "eager" |
| config.padding_side = "right" |
| config.use_cache = False |
| base_model = Phi3VForCausalLM.from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| elif model_backbone == LLAVA_NEXT: |
| config.use_cache = False |
| config.padding_side = "left" |
| base_model = LlavaNextForConditionalGeneration.from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| elif model_backbone in [QWEN2_VL, QWEN2_5_VL]: |
| config._attn_implementation = "flash_attention_2" |
| config.padding_side = "left" |
| config.use_cache = False |
| base_model = backbone2model[model_backbone].from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]: |
| config._attn_implementation = "flash_attention_2" |
| config.padding_side = "left" |
| config.use_cache = False |
|
|
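            # Token-selection backbones additionally expect per-layer skip specs parsed from the CLI arguments.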
| from .utils import parse_layer_type |
| lm_qwen_layer = 28 |
| vis_qwen_layer = 32 |
| lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer) |
| vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer) |
|
|
| base_model = backbone2model[model_backbone].from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| lm_skip_layer=lm_skip_layer, |
| vis_skip_layer=vis_skip_layer, |
| ) |
| else: |
| config.use_cache = False |
| base_model = cls.TRANSFORMER_CLS.from_pretrained( |
| model_args.model_name, **kwargs, config=config, |
| attn_implementation="flash_attention_2", |
| torch_dtype=torch.bfloat16, |
| trust_remote_code=True) |
|
|
| if model_args.lora: |
| |
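            # A directory counts as a LoRA checkpoint if it contains PEFT adapter weights or an adapter config.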
| def _has_adapter_files(path: str): |
| if not (path and os.path.isdir(path)): |
| return False |
| for fname in ("adapter_model.safetensors", "adapter_model.bin", "adapter_config.json"): |
| if os.path.exists(os.path.join(path, fname)): |
| return True |
| return False |
|
|
| if _has_adapter_files(model_args.model_name): |
| print_master(f"[build] detected LoRA adapter in '{model_args.model_name}', loading pretrained adapter.") |
| lora_config = LoraConfig.from_pretrained(model_args.model_name) |
| lora_model = PeftModel.from_pretrained( |
| base_model, model_args.model_name, config=lora_config, is_trainable=True |
| ) |
| |
| else: |
                print_master(f"[build] no adapter files found in '{model_args.model_name}'; creating a new LoRA adapter.")
| lora_config = LoraConfig( |
| r=model_args.lora_r, |
| lora_alpha=model_args.lora_alpha, |
| target_modules=model_args.lora_target_modules.split(','), |
| lora_dropout=model_args.lora_dropout, |
| init_lora_weights="gaussian", |
| use_dora=True, |
| inference_mode=False |
| ) |
| lora_model = get_peft_model(base_model, lora_config) |
|
|
| model = cls( |
| encoder=lora_model, |
| pooling=model_args.pooling, |
| normalize=model_args.normalize, |
| temperature=model_args.temperature |
| ) |
| |
| else: |
| model = cls( |
| encoder=base_model, |
| pooling=model_args.pooling, |
| normalize=model_args.normalize, |
| temperature=model_args.temperature |
| ) |
|
|
| |
| if getattr(model_args, "enable_early_mlp", False): |
| model.add_early_mlp( |
| layer_index=getattr(model_args, "early_layer_index", 20), |
| mlp_hidden_size=getattr(model_args, "early_mlp_hidden_size", None), |
| dropout=getattr(model_args, "early_mlp_dropout", 0.1), |
| ) |
| |
| model.early_loss_weight = float(getattr(model_args, "early_loss_weight", 0.0)) |
| |
| model.model_backbone = model_backbone |
| return model |
|
|
|
|
| @classmethod |
| def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs): |
| |
| model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name |
| config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) |
| if not hasattr(model_args, "model_backbone") or not model_args.model_backbone: |
| model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type) |
| setattr(model_args, 'model_backbone', model_backbone) |
| print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}') |
| if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}: |
| config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) |
| config._attn_implementation = "flash_attention_2" |
| config.vision_config._attn_implementation = "flash_attention_2" |
| base_model = backbone2model[model_args.model_backbone].from_pretrained( |
| model_args.model_name, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| config=config |
| ) |
| elif model_args.model_backbone == PHI3V: |
| config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) |
| config.use_cache = False |
| config.padding_side = "right" |
| base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config, |
| torch_dtype=torch.bfloat16, trust_remote_code=True) |
| base_model.padding_side = "right" |
| elif model_args.model_backbone == INTERNVIDEO2: |
            print_master(f'Loading backbone [{model_args.model_backbone}] from src/model/vlm_backbone/internvideo2/')
| config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/", |
| trust_remote_code=True) |
| base_model = backbone2model[model_args.model_backbone].from_pretrained("src/model/vlm_backbone/internvideo2/", config=config, |
| trust_remote_code=True) |
| elif model_args.model_backbone == GME: |
| base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor']) |
| setattr(base_model, 'config', config) |
| elif model_args.model_backbone == LamRA: |
| base_model = LamRAQwen2VL(model_args.model_name) |
| setattr(base_model, 'config', config) |
| elif model_args.model_backbone == LamRA_QWEN2_5: |
| base_model = LamRAQwen25VL(model_args.model_name) |
| setattr(base_model, 'config', config) |
| elif model_args.model_backbone == COLPALI: |
| base_model = ColPali.from_pretrained(model_args.model_name) |
| setattr(base_model, 'config', config) |
| else: |
| |
| config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) |
| config.use_cache = False |
| base_model = cls.TRANSFORMER_CLS.from_pretrained( |
| model_name_or_path, **kwargs, config=config, |
| torch_dtype=torch.bfloat16, |
| trust_remote_code=True) |
|
|
| |
| if model_args.lora: |
| print_master(f'Loading LoRA from {model_name_or_path}') |
| |
| def _has_adapter_files(path: str): |
| if not (path and os.path.isdir(path)): |
| return False |
| for fname in ("adapter_model.safetensors", "adapter_model.bin", "adapter_config.json"): |
| if os.path.exists(os.path.join(path, fname)): |
| return True |
| return False |
| adapter_source = model_name_or_path if _has_adapter_files(model_name_or_path) else model_args.model_name |
| if adapter_source != model_name_or_path: |
| print_master(f"[load] adapter files not found in '{model_name_or_path}', fallback to '{adapter_source}'") |
|
|
| |
| lora_config = LoraConfig.from_pretrained(adapter_source) |
| lora_model = PeftModel.from_pretrained(base_model, adapter_source, config=lora_config, is_trainable=is_trainable) |
| lora_model.load_adapter(adapter_source, lora_model.active_adapter, is_trainable=is_trainable) |
|
|
| if not is_trainable: |
| lora_model = lora_model.merge_and_unload() |
|
|
| model = cls( |
| encoder=lora_model, |
| pooling=model_args.pooling, |
| normalize=model_args.normalize, |
| temperature=model_args.temperature |
| ) |
| |
| else: |
| model = cls( |
| encoder=base_model, |
| pooling=model_args.pooling, |
| normalize=model_args.normalize, |
| temperature=model_args.temperature |
| ) |
|
|
| model.model_backbone = model_args.model_backbone |
|
|
| |
| enable_early = bool(getattr(model_args, "enable_early_mlp", False) or getattr(model.encoder.config, "enable_early_mlp", False)) |
| if enable_early: |
| layer_index = int(getattr(model_args, "early_layer_index", getattr(model.encoder.config, "early_layer_index", 20))) |
| mlp_hidden_size = getattr(model_args, "early_mlp_hidden_size", getattr(model.encoder.config, "early_mlp_hidden_size", None)) |
| dropout = getattr(model_args, "early_mlp_dropout", getattr(model.encoder.config, "early_mlp_dropout", 0.1)) |
| model.add_early_mlp(layer_index=layer_index, mlp_hidden_size=mlp_hidden_size, dropout=dropout) |
| model.early_loss_weight = float(getattr(model_args, "early_loss_weight", 0.0)) |
|
|
| |
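        # Optionally restore early_mlp weights saved alongside the encoder checkpoint.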
| ckpt_dir = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name |
| if os.path.isdir(ckpt_dir): |
| pt_path = os.path.join(ckpt_dir, "pytorch_model.bin") |
| st_path = os.path.join(ckpt_dir, "model.safetensors") |
| state = None |
|
|
| if os.path.exists(pt_path): |
| state = torch.load(pt_path, map_location="cpu") |
| elif os.path.exists(st_path) and safe_load is not None: |
| state = safe_load(st_path) |
|
|
| if state is not None: |
| early_state = { |
| k.replace("early_mlp.", "", 1): v |
| for k, v in state.items() |
| if k.startswith("early_mlp.") |
| } |
                if early_state and hasattr(model.encoder, "early_mlp"):
                    missing, unexpected = model.encoder.early_mlp.load_state_dict(early_state, strict=False)
| print_master( |
| f"Loaded early_mlp weights from {ckpt_dir} (missing={missing}, unexpected={unexpected})" |
| ) |
|
|
| |
| bin_path = os.path.join(ckpt_dir, "early_mlp.bin") |
| if os.path.exists(bin_path): |
| try: |
| early_bin = torch.load(bin_path, map_location="cpu") |
| missing, unexpected = model.encoder.early_mlp.load_state_dict(early_bin, strict=False) |
| print_master( |
| f"Loaded early_mlp weights from early_mlp.bin (missing={missing}, unexpected={unexpected})" |
| ) |
| except Exception as e: |
| print_master(f"Failed to load early_mlp from early_mlp.bin: {e}") |
| |
| return model |
|
|
| def save(self, output_dir: str): |
| self.encoder.save_pretrained(output_dir) |
|
|
| def forward(self, qry: Dict[str, torch.Tensor] = None, tgt: Dict[str, torch.Tensor] = None, |
| compute_early_only: bool = False, *args, **kwargs): |
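        """Contrastive forward pass.

        With both qry and tgt: returns the temperature-scaled InfoNCE loss (plus an optional auxiliary loss on the
        early-exit representations). With only one side: returns a dict of its pooled representations for inference.
        """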
| |
| only_q = (qry is not None) and (tgt is None) |
| only_t = (tgt is not None) and (qry is None) |
| if only_q or only_t: |
| |
| single = qry if only_q else tgt |
| rep = self.encode_input(single, return_early=False) |
| return {"qry_reps": rep if only_q else None, "tgt_reps": rep if only_t else None} |
| |
| if qry is None and tgt is None: |
| raise ValueError("MMEBModel.forward expected 'qry' and/or 'tgt' but got none. " |
| "Check split_and_process_vlm_inputs / training_step packaging.") |

        # compute_early_only scores the batch with the early-exit representations; otherwise the final-layer
        # representations are used, with early ones computed alongside only when the auxiliary loss is active.
| if compute_early_only: |
| if not self.enable_early_mlp: |
                raise RuntimeError("compute_early_only=True but early_mlp is not enabled")
| qry_final, qry_early = self.encode_input(qry, return_early=True) |
| tgt_final, tgt_early = self.encode_input(tgt, return_early=True) |
| assert qry_early is not None and tgt_early is not None |
| qry_reps, tgt_reps = qry_early, tgt_early |
| else: |
| if self.enable_early_mlp and (self.training and self.early_loss_weight > 0): |
| qry_final, qry_early = self.encode_input(qry, return_early=True) |
| tgt_final, tgt_early = self.encode_input(tgt, return_early=True) |
| else: |
| qry_final = self.encode_input(qry, return_early=False) |
| tgt_final = self.encode_input(tgt, return_early=False) |
| qry_early, tgt_early = None, None |
| qry_reps, tgt_reps = qry_final, tgt_final |
|
|
| |
| if self.is_ddp: |
| all_qry_reps = self._dist_gather_tensor(qry_reps) |
| all_tgt_reps = self._dist_gather_tensor(tgt_reps) |
| else: |
| all_qry_reps = qry_reps |
| all_tgt_reps = tgt_reps |
|
|
| |
        scores = self.compute_similarity(all_qry_reps, all_tgt_reps).view(all_qry_reps.size(0), -1)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        # Each query's positive target sits every (num_targets // num_queries) columns of the score matrix.
        target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
| loss = self.cross_entropy(scores / self.temperature, target) |
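        # Scale the loss back up: DDP averages gradients across ranks, while this loss already spans
        # the examples gathered from every rank.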
| if self.is_ddp: |
| loss = loss * self.world_size |
|
|
| |
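        # Optional auxiliary InfoNCE loss on the early-exit representations (training only).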
| if (not compute_early_only) and self.training and self.enable_early_mlp and self.early_loss_weight > 0: |
| assert qry_early is not None and tgt_early is not None |
| if self.is_ddp: |
| all_qry_e = self._dist_gather_tensor(qry_early) |
| all_tgt_e = self._dist_gather_tensor(tgt_early) |
| else: |
| all_qry_e, all_tgt_e = qry_early, tgt_early |
            scores_e = self.compute_similarity(all_qry_e, all_tgt_e).view(all_qry_e.size(0), -1)
            target_e = torch.arange(scores_e.size(0), device=scores_e.device, dtype=torch.long)
            target_e = target_e * (all_tgt_e.size(0) // all_qry_e.size(0))
| loss_e = self.cross_entropy(scores_e / self.temperature, target_e) |
| if self.is_ddp: |
| loss_e = loss_e * self.world_size |
| loss = loss + self.early_loss_weight * loss_e |
|
|
| return loss |
|
|
| def _dist_gather_tensor(self, t: Tensor): |
| t = t.contiguous() |
| all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] |
| dist.all_gather(all_tensors, t) |
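        # Put the local tensor back so gradients flow through this rank's own embeddings
        # (tensors produced by all_gather do not carry gradients).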
| all_tensors[self.process_rank] = t |
| all_tensors = torch.cat(all_tensors, dim=0) |
| return all_tensors |
|
|
| def compute_similarity(self, q_reps, p_reps): |
| return torch.matmul(q_reps, p_reps.transpose(0, 1)) |
|
|