| from typing import Dict |
| import torch, os |
| import torch.distributed as dist |
| from torch import nn, Tensor |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig |
| from peft import LoraConfig, get_peft_model, PeftModel |
| from src.model.processor import QWEN2_5_VL_TOKENSELECTION |
| from src.arguments import ModelArguments, TrainingArguments |
| from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, \ |
| backbone2model, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V |
|
|
| from src.arguments import ModelArguments |
| from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, INTERNVIDEO2, \ |
| QWEN2_VL_TOKENSELECTION, backbone2model, GME, VLM_IMAGE_TOKENS, LamRA, LamRA_QWEN2_5, COLPALI |
| from src.model.baseline_backbone.colpali import ColPali |
| from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL |
| from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL |
| from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL |
| from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM |
| from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration |
|
|
| from transformers import modeling_utils |
# Compatibility shim: make sure `modeling_utils.ALL_PARALLEL_STYLES` exists and
# is not None before any model class is loaded. An attribute that is missing
# and one that is present-but-None are treated the same way.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", 'rowwise']
|
|
|
|
class MMEBModel(nn.Module):
    """Contrastive embedding wrapper around a (multimodal) backbone encoder.

    In addition to the pooled last-layer embedding, the model owns a small MLP
    head (``proj20``) that is applied to the pooled hidden state of an
    intermediate layer (``dual_layer_idx``, 20 by default).
    ``training_stage`` selects what ``forward`` produces:

    * stage 1 -- only the projected intermediate-layer embedding;
    * stage 2 -- both embeddings, stacked along dim 1.
    """

    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'last',
                 normalize: bool = False,
                 temperature: float = 0.02,
                 ):
        """Wrap ``encoder`` and create the intermediate-layer projection head.

        Args:
            encoder: backbone model; its ``config`` is exposed as ``self.config``.
            pooling: pooling mode; ``_pooling`` supports ``'last'`` and ``'eos'``.
            normalize: whether embeddings are L2-normalised.
            temperature: divisor applied to similarity logits in the loss.
        """
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            # Only defined when a process group exists; every use is guarded
            # by ``self.is_ddp``.
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

        self.layer_indices = [20, -1]
        self.dual_layer_idx = 20   # hidden-state index fed to ``proj20``
        self.dual_alpha = 0.05     # weight of the intermediate-layer loss in stage 2

        hidden_size = self._infer_hidden_size(self.config)
        self.proj20 = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
        )

        self.training_stage = 1

    @staticmethod
    def _infer_hidden_size(config, default: int = 1024) -> int:
        """Return the hidden width from ``config``, falling back to ``default``.

        Composite configs may not expose ``hidden_size`` at the top level, so
        ``config.text_config.hidden_size`` is consulted as well before giving up.
        """
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None:
            text_config = getattr(config, "text_config", None)
            hidden_size = getattr(text_config, "hidden_size", None)
        return default if hidden_size is None else hidden_size

    # ------------------------------------------------------------------ #
    # Backbone helpers
    # ------------------------------------------------------------------ #
    def _is_qwen2_series(self):
        """Return True when ``model_backbone`` is one of the Qwen2(-.5)-VL variants."""
        return getattr(self, "model_backbone", None) in {
            QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION
        }

    def _has_no_hidden_states(self):
        """Return True for backbones that are only encoded through ``encode_input``.

        For these backbones the dual / intermediate-layer paths fall back to
        the single embedding returned by ``encode_input``.
        """
        return getattr(self, "model_backbone", None) in {
            INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI
        }

    def _squeeze_mm_inputs(self, inp: Dict) -> Dict:
        """Return a shallow copy of ``inp`` with dim 1 squeezed from image tensors.

        ``pixel_values`` is squeezed when it is a 5-D tensor and ``image_sizes``
        when it is a tensor with at least 3 dims; the caller's dict is not
        modified.
        """
        x = dict(inp)
        if "pixel_values" in x and isinstance(x["pixel_values"], torch.Tensor) and x["pixel_values"].dim() == 5:
            x["pixel_values"] = x["pixel_values"].squeeze(1)
        if "image_sizes" in x and isinstance(x["image_sizes"], torch.Tensor) and x["image_sizes"].dim() >= 3:
            x["image_sizes"] = x["image_sizes"].squeeze(1)
        return x

    def _forward_hidden_states(self, input):
        """Run the encoder and return ``(hidden_states, attention_mask)``.

        Inputs are squeezed first for LLaVA-Next and the Qwen2 series only;
        other backbones receive ``input`` unchanged.

        Raises:
            RuntimeError: if the encoder output exposes neither
                ``hidden_states`` nor ``text_hidden_states``.
        """
        mb = getattr(self, "model_backbone", None)
        if mb == LLAVA_NEXT or self._is_qwen2_series():
            input = self._squeeze_mm_inputs(input)
        out = self.encoder(**input, return_dict=True, output_hidden_states=True)
        hs = getattr(out, "hidden_states", None)
        if hs is None:
            hs = getattr(out, "text_hidden_states", None)
        if hs is None:
            raise RuntimeError("hidden_states is None; ensure output_hidden_states=True and trust_remote_code=True.")
        return hs, input["attention_mask"]

    def _resolve_layer_index(self, num_states: int) -> int:
        """Map ``dual_layer_idx`` to a valid non-negative index into the hidden states.

        Negative values count from the end. The result is clamped to
        ``[1, num_states - 1]``, or to 0 when only one state is available.
        """
        idx = int(getattr(self, "dual_layer_idx", 20))
        if idx < 0:
            idx += num_states
        lowest = min(1, num_states - 1)
        return max(lowest, min(idx, num_states - 1))

    def _project20(self, rep):
        """Apply ``proj20`` to ``rep``, bridging any dtype gap.

        The head's parameters can have a different dtype from the encoder
        output (e.g. a bf16 encoder with a freshly created fp32 head). The
        input is cast to the head's dtype and the result cast back, so the
        call also works outside autocast.
        """
        head_dtype = self.proj20[1].weight.dtype
        return self.proj20(rep.to(head_dtype)).to(rep.dtype)

    # ------------------------------------------------------------------ #
    # Training-stage control
    # ------------------------------------------------------------------ #
    def set_training_stage(self, stage: int, freeze_encoder: bool = True, verbose: bool = True):
        """Select the training stage and, optionally, update ``requires_grad``.

        stage=1: only the intermediate-layer (+MLP) loss is used; with
            ``freeze_encoder`` everything except ``proj20`` is frozen.
        stage=2: intermediate-layer (+MLP) loss plus last-layer loss; with
            ``freeze_encoder`` every parameter is made trainable.

        When ``freeze_encoder`` is False only ``training_stage`` changes.

        Note: call this before creating the optimizer, or re-create the
        optimizer after switching stages.
        """
        assert stage in (1, 2), "stage 只能为 1 或 2"
        self.training_stage = stage

        if not freeze_encoder:
            return

        if stage == 1:
            for p in self.parameters():
                p.requires_grad = False
            for p in self.proj20.parameters():
                p.requires_grad = True
            if verbose:
                print("[MMEB] Stage 1: 冻结 encoder,仅训练 proj20")
        else:
            # NOTE(review): this marks *every* parameter trainable, including
            # base weights under a LoRA-wrapped encoder -- confirm this is intended.
            for p in self.parameters():
                p.requires_grad = True
            if verbose:
                print("[MMEB] Stage 2: 训练全模型(含 proj20 和 encoder)")

    # ------------------------------------------------------------------ #
    # Encoding
    # ------------------------------------------------------------------ #
    def _encode_20_raw(self, input):
        """Return the pooled intermediate-layer vector, without projection.

        Backbones without hidden states return ``encode_input(input)`` instead.
        """
        if self._has_no_hidden_states():
            return self.encode_input(input)
        hs, mask = self._forward_hidden_states(input)
        return self._pooling(hs[self._resolve_layer_index(len(hs))], mask)

    def _encode_20_proj(self, input):
        """Return the intermediate-layer vector passed through ``proj20``."""
        rep20 = self._project20(self._encode_20_raw(input))
        if self.normalize:
            rep20 = F.normalize(rep20, p=2, dim=-1)
        return rep20

    def _encode_dual(self, input):
        """Return the projected intermediate-layer and last-layer vectors, stacked on dim 1.

        For backbones without hidden states both slots are derived from the
        single ``encode_input`` embedding; only the first goes through ``proj20``.
        """
        if self._has_no_hidden_states():
            replast = self.encode_input(input)
            rep20 = self._project20(replast)
        else:
            hs, mask = self._forward_hidden_states(input)
            rep20 = self._project20(self._pooling(hs[self._resolve_layer_index(len(hs))], mask))
            replast = self._pooling(hs[-1], mask)
        if self.normalize:
            rep20 = F.normalize(rep20, p=2, dim=-1)
            replast = F.normalize(replast, p=2, dim=-1)
        return torch.stack([rep20, replast], dim=1)

    def _encode_query_dual(self, input):
        """Return ``[B, 2, D]`` for queries: projected intermediate layer and last layer."""
        return self._encode_dual(input)

    def _encode_target_dual(self, input):
        """Return ``[B, 2, D]`` for candidates: projected intermediate layer and last layer."""
        return self._encode_dual(input)

    def encode_input(self, input, layer_indices=None):
        """Encode ``input`` with the backbone-specific path and return pooled embeddings.

        ``layer_indices`` is only honoured on the generic path: ``None`` selects
        the last hidden state, an int selects that state, and an iterable
        returns the pooled states stacked on dim 1.
        """
        mb = getattr(self, "model_backbone", None)

        if mb == INTERNVIDEO2:
            if "input_ids" in input.keys():
                text_output = self.encoder.get_text_encoder()(
                    input["input_ids"],
                    attention_mask=input["attention_mask"],
                    return_dict=True,
                    mode="text",
                )
                # The embedding at position 0 is used as the text summary.
                pooled_output = self.encoder.text_proj(text_output.last_hidden_state[:, 0])
                pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
                return pooled_output
            _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
            vfeat = self.encoder.vision_proj(vfeat)
            vfeat /= vfeat.norm(dim=-1, keepdim=True)
            return vfeat

        if mb in [GME, LamRA, LamRA_QWEN2_5]:
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]
            images = []
            for imgs in input['images']:
                if isinstance(imgs, list):
                    # Several images given: keep the middle one only.
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list)
                images.append(imgs)
            return self.encoder.get_fused_embeddings(texts=texts, images=images)

        if mb == COLPALI:
            # The encoder's own output is returned as-is (no pooling here).
            return self.encoder(**input, return_dict=True, output_hidden_states=True)

        if mb == LLAVA_NEXT or self._is_qwen2_series():
            # NOTE: ``layer_indices`` is ignored on this path; the last layer is always used.
            inp = self._squeeze_mm_inputs(input)
            out = self.encoder(**inp, return_dict=True, output_hidden_states=True)
            return self._pooling(out.hidden_states[-1], inp['attention_mask'])

        out = self.encoder(**input, return_dict=True, output_hidden_states=True)
        hs_list = out.hidden_states
        mask = input['attention_mask']
        if layer_indices is None:
            return self._pooling(hs_list[-1], mask)
        if isinstance(layer_indices, int):
            return self._pooling(hs_list[layer_indices], mask)
        return torch.stack([self._pooling(hs_list[idx], mask) for idx in layer_indices], dim=1)

    def _pooling(self, last_hidden_state, attention_mask):
        """Pool one vector per sequence from ``last_hidden_state``.

        For ``'last'`` / ``'eos'`` pooling: if every row's final mask entry is
        set (treated as left padding) the final position is taken; otherwise
        the position ``attention_mask.sum(dim=1) - 1`` of each row is used.
        The result is L2-normalised when ``self.normalize`` is set.

        Raises:
            NotImplementedError: for any other pooling mode.
        """
        if self.pooling not in ('last', 'eos'):
            raise NotImplementedError

        batch_size = last_hidden_state.shape[0]
        rows = torch.arange(batch_size, device=last_hidden_state.device)
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[rows, -1, :]
        else:
            eos_indices = attention_mask.sum(dim=1) - 1
            reps = last_hidden_state[rows, eos_indices]

        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    # ------------------------------------------------------------------ #
    # Construction / persistence
    # ------------------------------------------------------------------ #
    @classmethod
    def build(cls, model_args: ModelArguments, **kwargs):
        """Create a model for training from ``model_args.model_name``.

        The backbone is detected from the HF config (or forced to the
        layer-prune variant when ``config.backbone_variant == "layerprune"``),
        loaded in bfloat16, and optionally wrapped with a new LoRA/DoRA adapter.
        """
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        variant = getattr(config, "backbone_variant", None)
        if variant == "layerprune":
            model_backbone = "QWEN2_VL_LayerPrune"
        else:
            model_backbone = get_backbone_name(hf_config=config)
        print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')

        if model_backbone == PHI3V:
            config._attn_implementation = "eager"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone == LLAVA_NEXT:
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL, "QWEN2_VL_LayerPrune"]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
            config._attn_implementation = "flash_attention_2"
            config.padding_side = "left"
            config.use_cache = False

            from .utils import parse_layer_type
            lm_qwen_layer = 28
            vis_qwen_layer = 32
            lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
            vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)

            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                lm_skip_layer=lm_skip_layer,
                vis_skip_layer=vis_skip_layer,
            )
        else:
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        encoder = base_model
        if model_args.lora:
            print_master(f'Loading lora adapter from {base_model}')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=model_args.lora_target_modules.split(','),
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            encoder = get_peft_model(base_model, lora_config)

        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )
        # Keep in line with ``load``: the encoding paths dispatch on this attribute.
        model.model_backbone = model_backbone
        return model

    @classmethod
    def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
        """Load a model (and optional LoRA adapter / extra heads) for eval or resume.

        The checkpoint location is ``model_args.checkpoint_path`` when set,
        otherwise ``model_args.model_name``. With LoRA and
        ``is_trainable=False`` the adapter is merged into the base weights.
        Extra-head loading is best-effort: failures are reported, not raised.
        """
        model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
            model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
            setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')

        if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config._attn_implementation = "flash_attention_2"
            config.vision_config._attn_implementation = "flash_attention_2"
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                config=config
            )
        elif model_args.model_backbone == PHI3V:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
                                                          torch_dtype=torch.bfloat16, trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            print_master(f'Loading backbone [{model_args.model_backbone}] from {"src/model/vlm_backbone/internvideo2/"}')
            config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/",
                                                trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained("src/model/vlm_backbone/internvideo2/", config=config,
                                                                                   trust_remote_code=True)
        elif model_args.model_backbone == GME:
            base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA:
            base_model = LamRAQwen2VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA_QWEN2_5:
            base_model = LamRAQwen25VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == COLPALI:
            base_model = ColPali.from_pretrained(model_args.model_name)
            setattr(base_model, 'config', config)
        else:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_name_or_path, **kwargs, config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        encoder = base_model
        if model_args.lora:
            print_master(f'Loading LoRA from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)
            lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
            lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
            if not is_trainable:
                lora_model = lora_model.merge_and_unload()
            encoder = lora_model

        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )
        model.model_backbone = model_args.model_backbone

        # Best-effort restore of heads stored next to the checkpoint
        # (e.g. ``proj20``); the first existing candidate file wins.
        try:
            ckpt_dir = model_args.checkpoint_path or model_args.model_name
            extra_candidates = [
                os.path.join(ckpt_dir, "mmeb_extra.pt"),
                os.path.join(ckpt_dir, "mmeb_extra.bin"),
                os.path.join(ckpt_dir, "extra_heads.pt"),
                os.path.join(ckpt_dir, "proj20.pt"),
            ]
            extra_path = next((p for p in extra_candidates if os.path.isfile(p)), None)
            if extra_path:
                # SECURITY: torch.load unpickles the file -- only point
                # ``checkpoint_path`` at trusted checkpoints.
                extra_sd = torch.load(extra_path, map_location="cpu")
                missing, unexpected = model.load_state_dict(extra_sd, strict=False)
                print_master(f"Loaded extra heads from {extra_path}. "
                             f"missing={len(missing)}, unexpected={len(unexpected)}")
        except Exception as e:
            print_master(f"[WARN] Failed to load extra heads: {e}")
        return model

    def save(self, output_dir: str):
        """Save the encoder and the ``proj20`` head to ``output_dir``.

        The head is written to ``mmeb_extra.pt`` with ``proj20.``-prefixed
        keys, which is the first file ``load`` looks for and the key layout
        its ``load_state_dict(..., strict=False)`` call expects.
        """
        self.encoder.save_pretrained(output_dir)
        os.makedirs(output_dir, exist_ok=True)
        extra_sd = {f"proj20.{k}": v.detach().cpu() for k, v in self.proj20.state_dict().items()}
        torch.save(extra_sd, os.path.join(output_dir, "mmeb_extra.pt"))

    # ------------------------------------------------------------------ #
    # Forward / loss
    # ------------------------------------------------------------------ #
    def _encode_side(self, input, is_query: bool):
        """Encode one side (query or target) according to ``training_stage``."""
        if self.training_stage == 1:
            return self._encode_20_proj(input)
        if is_query:
            return self._encode_query_dual(input)
        return self._encode_target_dual(input)

    def _contrastive_loss(self, q: Tensor, t: Tensor) -> Tensor:
        """Cross-entropy over ``q @ t.T / temperature`` with row ``i`` labelled ``i``."""
        logits = torch.matmul(q, t.transpose(0, 1)) / self.temperature
        labels = torch.arange(logits.size(0), device=logits.device, dtype=torch.long)
        return self.cross_entropy(logits, labels)

    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
        """Encode and/or compute the contrastive loss.

        * Only ``qry`` or only ``tgt`` given: returns
          ``{"qry_reps": ..., "tgt_reps": ...}`` with the missing side ``None``.
        * Both given: returns the scalar loss. In stage 2 it is
          ``alpha * loss_mid + (1 - alpha) * loss_last`` with
          ``alpha = dual_alpha``. Under DDP the representations are gathered
          across ranks and the loss is multiplied by ``world_size``.
        * Neither given: returns a dict with both entries ``None``.
        """
        if qry is None and tgt is None:
            return {"qry_reps": None, "tgt_reps": None}
        if tgt is None:
            return {"qry_reps": self._encode_side(qry, is_query=True), "tgt_reps": None}
        if qry is None:
            return {"qry_reps": None, "tgt_reps": self._encode_side(tgt, is_query=False)}

        q = self._encode_side(qry, is_query=True)
        t = self._encode_side(tgt, is_query=False)
        if self.is_ddp:
            q = self._dist_gather_tensor(q)
            t = self._dist_gather_tensor(t)

        if self.training_stage == 1:
            loss = self._contrastive_loss(q, t)
        else:
            alpha = getattr(self, "dual_alpha", 0.2)
            loss20 = self._contrastive_loss(q[:, 0, :], t[:, 0, :])
            lossL = self._contrastive_loss(q[:, 1, :], t[:, 1, :])
            loss = alpha * loss20 + (1.0 - alpha) * lossL

        if self.is_ddp:
            loss = loss * self.world_size
        return loss

    def _dist_gather_tensor(self, t: Tensor):
        """All-gather ``t`` across ranks and concatenate along dim 0.

        The local slot is replaced with the original ``t`` so that gradients
        still flow through this rank's own representations.
        """
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps, p_reps):
        """Return the dot-product similarity matrix ``q_reps @ p_reps.T``."""
        return torch.matmul(q_reps, p_reps.transpose(0, 1))
|
|