| """v0.1版本""" |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModel, CLIPVisionModel, PreTrainedModel, PretrainedConfig |
|
|
| class MiniEmbedVisionConfig(PretrainedConfig): |
|
|
| model_type = "miniembedvision" |
|
|
| def __init__( |
| self, |
| embed_dim: int = 768, |
| text_model_name: str = "BAAI/bge-base-zh-v1.5", |
| vision_model_name: str = "openai/clip-vit-base-patch32", |
| freeze_text: bool = True, |
| use_lora: bool = False, |
| lora_r: int = 8, |
| lora_alpha: int = 16, |
| lora_dropout: float = 0.05, |
| **kwargs |
| ): |
| super().__init__(**kwargs) |
| self.embed_dim = embed_dim |
| self.text_model_name = text_model_name |
| self.vision_model_name = vision_model_name |
| self.freeze_text = freeze_text |
| self.use_lora = use_lora |
| self.lora_r = lora_r |
| self.lora_alpha = lora_alpha |
| self.lora_dropout = lora_dropout |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModel, CLIPVisionModel, PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
| from typing import Optional, Tuple, Union |
|
|
| try: |
| from peft import LoraConfig, get_peft_model, TaskType |
| _has_peft = True |
| except ImportError: |
| _has_peft = False |
|
|
| def _concat_all_gather(tensor): |
| if not torch.distributed.is_initialized(): |
| return tensor |
| tensors_gather = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] |
| torch.distributed.all_gather(tensors_gather, tensor) |
| return torch.cat(tensors_gather, dim=0) |
|
|
| class MiniEmbedVisionModel(PreTrainedModel): |
| config_class = MiniEmbedVisionConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| |
| |
| self.text_encoder = AutoModel.from_pretrained(config.text_model_name) |
| text_hidden = self.text_encoder.config.hidden_size |
| |
| if config.freeze_text: |
| for p in self.text_encoder.parameters(): |
| p.requires_grad = False |
| self.text_encoder.eval() |
| |
| |
| self.vision_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name) |
| vis_hidden = self.vision_encoder.config.hidden_size |
|
|
| |
| if config.use_lora: |
| if not _has_peft: |
| raise ImportError("peft is required for LoRA. Please install: pip install peft") |
| lora_config = LoraConfig( |
| r=config.lora_r, |
| lora_alpha=config.lora_alpha, |
| target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2", "proj"], |
| lora_dropout=config.lora_dropout, |
| bias="none", |
| task_type=TaskType.OTHER, |
| ) |
| self.vision_encoder = get_peft_model(self.vision_encoder, lora_config) |
|
|
| self.text_proj = nn.Linear(text_hidden, config.embed_dim) if text_hidden != config.embed_dim else nn.Identity() |
| self.vision_proj = nn.Linear(vis_hidden, config.embed_dim) if vis_hidden != config.embed_dim else nn.Identity() |
| self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07))) |
| |
| def _bge_pool(self, last_hidden, attention_mask): |
| mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float() |
| sum_emb = torch.sum(last_hidden * mask, dim=1) |
| sum_mask = mask.sum(dim=1) |
| sum_mask = torch.clamp(sum_mask, min=1e-9) |
| return sum_emb / sum_mask |
|
|
| def encode_text(self, input_ids, attention_mask): |
| with torch.no_grad() if not any(p.requires_grad for p in self.text_encoder.parameters()) else torch.enable_grad(): |
| outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True) |
| pooled = self._bge_pool(outputs.last_hidden_state, attention_mask) |
| emb = self.text_proj(pooled) |
| return F.normalize(emb, p=2, dim=-1) |
|
|
| def encode_image(self, pixel_values): |
| outputs = self.vision_encoder(pixel_values=pixel_values, return_dict=True) |
| cls_feat = outputs.last_hidden_state[:, 0] |
| emb = self.vision_proj(cls_feat) |
| return F.normalize(emb, p=2, dim=-1) |
| |
| def forward( |
| self, |
| input_ids: Optional[torch.Tensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| return_loss: bool = False, |
| gather_for_ddp: bool = True, |
| labels: Optional[torch.Tensor] = None, |
| ) -> Union[ModelOutput, Tuple[torch.Tensor, ...]]: |
| """ |
| 模拟 Nomic 风格:既支持文本也支持图像 |
| Supports: |
| - Inference: only text OR only image → returns embeddings. |
| - Training: both text AND image → returns logits and (optionally) loss. |
| """ |
| text_emb = None |
| image_emb = None |
|
|
| |
| if input_ids is not None and attention_mask is not None: |
| text_emb = self.encode_text(input_ids, attention_mask) |
| if pixel_values is not None: |
| image_emb = self.encode_image(pixel_values) |
|
|
| |
| if text_emb is not None and image_emb is None: |
| return ModelOutput(last_hidden_state=text_emb, text_embeds=text_emb) |
| if image_emb is not None and text_emb is None: |
| return ModelOutput(last_hidden_state=image_emb, image_embeds=image_emb) |
|
|
| |
| if text_emb is None or image_emb is None: |
| raise ValueError("For training, both text and image inputs are required.") |
|
|
| |
| if gather_for_ddp and torch.distributed.is_initialized(): |
| text_emb_all = _concat_all_gather(text_emb) |
| image_emb_all = _concat_all_gather(image_emb) |
| else: |
| text_emb_all = text_emb |
| image_emb_all = image_emb |
|
|
| logit_scale = torch.clamp(self.logit_scale.exp(), max=100.0) |
| logits_per_text = logit_scale * text_emb @ image_emb_all.t() |
| logits_per_image = logits_per_text.t() |
|
|
| loss = None |
| if return_loss: |
| |
| local_batch_size = text_emb.size(0) |
| global_batch_size = text_emb_all.size(0) |
| |
| labels = torch.arange(local_batch_size, device=text_emb.device) |
| |
| |
| |
| if torch.distributed.is_initialized(): |
| rank = torch.distributed.get_rank() |
| local_image_start = rank * local_batch_size |
| image_indices = torch.arange(local_image_start, local_image_start + local_batch_size, device=text_emb.device) |
| |
| |
| logits_local = logit_scale * text_emb @ image_emb.t() |
| labels_local = torch.arange(local_batch_size, device=text_emb.device) |
| loss_i = F.cross_entropy(logits_local, labels_local) |
| loss_t = F.cross_entropy(logits_local.t(), labels_local) |
| loss = (loss_i + loss_t) / 2 |
| else: |
| logits_local = logits_per_text |
| labels_local = torch.arange(local_batch_size, device=text_emb.device) |
| loss_i = F.cross_entropy(logits_local, labels_local) |
| loss_t = F.cross_entropy(logits_local.t(), labels_local) |
| loss = (loss_i + loss_t) / 2 |
|
|
| return ModelOutput( |
| loss=loss, |
| logits_per_text=logits_per_text, |
| logits_per_image=logits_per_image, |
| text_embeds=text_emb, |
| image_embeds=image_emb, |
| ) |
|
|
| MiniEmbedVisionConfig.register_for_auto_class() |
| MiniEmbedVisionModel.register_for_auto_class("AutoModel") |