"""v0.1 version"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, CLIPVisionModel, PreTrainedModel, PretrainedConfig
class MiniEmbedVisionConfig(PretrainedConfig):
    model_type = "miniembedvision"

    def __init__(
        self,
        embed_dim: int = 768,
        text_model_name: str = "BAAI/bge-base-zh-v1.5",
        vision_model_name: str = "openai/clip-vit-base-patch32",
        freeze_text: bool = True,
        use_lora: bool = False,
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.text_model_name = text_model_name
        self.vision_model_name = vision_model_name
        self.freeze_text = freeze_text
        self.use_lora = use_lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

from transformers.modeling_outputs import ModelOutput
from typing import Optional, Tuple, Union
try:
    from peft import LoraConfig, get_peft_model, TaskType
    _has_peft = True
except ImportError:
    _has_peft = False

def _concat_all_gather(tensor):
    """All-gather a tensor from every process; the gathered copies carry no gradient."""
    if not torch.distributed.is_initialized():
        return tensor
    tensors_gather = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor)
    return torch.cat(tensors_gather, dim=0)

class MiniEmbedVisionModel(PreTrainedModel):
    config_class = MiniEmbedVisionConfig

    def __init__(self, config):
        super().__init__(config)
        # Text encoder
        self.text_encoder = AutoModel.from_pretrained(config.text_model_name)
        text_hidden = self.text_encoder.config.hidden_size
        # Optionally freeze the text encoder (freeze_text)
        if config.freeze_text:
            for p in self.text_encoder.parameters():
                p.requires_grad = False
            self.text_encoder.eval()
        # Vision encoder
        self.vision_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
        vis_hidden = self.vision_encoder.config.hidden_size
        # Optionally wrap the vision encoder with LoRA adapters
        if config.use_lora:
            if not _has_peft:
                raise ImportError("peft is required for LoRA. Please install: pip install peft")
            lora_config = LoraConfig(
                r=config.lora_r,
                lora_alpha=config.lora_alpha,
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2", "proj"],
                lora_dropout=config.lora_dropout,
                bias="none",
                task_type=TaskType.FEATURE_EXTRACTION,  # embedding model; peft's TaskType has no OTHER member
            )
            self.vision_encoder = get_peft_model(self.vision_encoder, lora_config)
        # Project both modalities into the shared embedding space
        self.text_proj = nn.Linear(text_hidden, config.embed_dim) if text_hidden != config.embed_dim else nn.Identity()
        self.vision_proj = nn.Linear(vis_hidden, config.embed_dim) if vis_hidden != config.embed_dim else nn.Identity()
        # Learnable temperature, initialized to 1/0.07 as in CLIP
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))
    def _bge_pool(self, last_hidden, attention_mask):
        """Mean-pool the last hidden states over non-padded tokens."""
        mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
        sum_emb = torch.sum(last_hidden * mask, dim=1)
        sum_mask = mask.sum(dim=1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        return sum_emb / sum_mask

    def encode_text(self, input_ids, attention_mask):
        # Skip gradient tracking when the text encoder is fully frozen
        text_frozen = not any(p.requires_grad for p in self.text_encoder.parameters())
        with torch.no_grad() if text_frozen else torch.enable_grad():
            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = self._bge_pool(outputs.last_hidden_state, attention_mask)
        emb = self.text_proj(pooled)
        return F.normalize(emb, p=2, dim=-1)

    def encode_image(self, pixel_values):
        outputs = self.vision_encoder(pixel_values=pixel_values, return_dict=True)
        cls_feat = outputs.last_hidden_state[:, 0]  # CLS token feature
        emb = self.vision_proj(cls_feat)
        return F.normalize(emb, p=2, dim=-1)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        return_loss: bool = False,
        gather_for_ddp: bool = True,
        labels: Optional[torch.Tensor] = None,  # reserved for supervised contrastive
    ) -> Union[ModelOutput, Tuple[torch.Tensor, ...]]:
        """
        Nomic-style dual-mode forward: handles text, images, or both.

        Supports:
        - Inference: only text OR only image → returns embeddings.
        - Training: both text AND image → returns logits and (optionally) loss.
        """
        text_emb = None
        image_emb = None
        # Encode modalities if provided
        if input_ids is not None and attention_mask is not None:
            text_emb = self.encode_text(input_ids, attention_mask)
        if pixel_values is not None:
            image_emb = self.encode_image(pixel_values)
        # Inference mode: single modality
        if text_emb is not None and image_emb is None:
            return ModelOutput(last_hidden_state=text_emb, text_embeds=text_emb)
        if image_emb is not None and text_emb is None:
            return ModelOutput(last_hidden_state=image_emb, image_embeds=image_emb)
        # Training mode: both modalities present
        if text_emb is None or image_emb is None:
            raise ValueError("For training, both text and image inputs are required.")
        # Gather across GPUs for large-batch negatives
        if gather_for_ddp and torch.distributed.is_initialized():
            text_emb_all = _concat_all_gather(text_emb)
            image_emb_all = _concat_all_gather(image_emb)
        else:
            text_emb_all = text_emb
            image_emb_all = image_emb
        logit_scale = torch.clamp(self.logit_scale.exp(), max=100.0)
        logits_per_text = logit_scale * text_emb @ image_emb_all.t()  # [B, global_B]
        logits_per_image = logits_per_text.t()  # [global_B, B]
        loss = None
        if return_loss:
            # Assume 1:1 text-image pairing within the local batch.
            local_batch_size = text_emb.size(0)
            labels_local = torch.arange(local_batch_size, device=text_emb.device)
            if torch.distributed.is_initialized():
                # logits_per_text is [local_B, global_B] and the gathered copies carry no
                # gradient, so the loss is computed over the local batch only
                # (in-batch negatives), which is a common simplification in practice.
                logits_local = logit_scale * text_emb @ image_emb.t()  # [B, B]
            else:
                logits_local = logits_per_text  # already [B, B]
            loss_i = F.cross_entropy(logits_local, labels_local)
            loss_t = F.cross_entropy(logits_local.t(), labels_local)
            loss = (loss_i + loss_t) / 2
        return ModelOutput(
            loss=loss,
            logits_per_text=logits_per_text,
            logits_per_image=logits_per_image,
            text_embeds=text_emb,
            image_embeds=image_emb,
        )

MiniEmbedVisionConfig.register_for_auto_class()
MiniEmbedVisionModel.register_for_auto_class("AutoModel")
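
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): illustrates the three
# call patterns of forward() — text-only inference, image-only inference, and
# a paired text+image step with return_loss=True. It assumes the default
# checkpoints (BAAI/bge-base-zh-v1.5, openai/clip-vit-base-patch32) can be
# downloaded from the Hugging Face Hub, and uses random pixel values instead
# of a real preprocessed image batch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = MiniEmbedVisionConfig()
    model = MiniEmbedVisionModel(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name)
    texts = ["a photo of a cat", "a photo of a dog"]
    text_inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    # 1) Text-only inference -> L2-normalized text embeddings, shape [2, embed_dim]
    with torch.no_grad():
        text_out = model(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
        )
    print(text_out["text_embeds"].shape)

    # 2) Image-only inference (random tensor stands in for processed pixels)
    pixel_values = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        image_out = model(pixel_values=pixel_values)
    print(image_out["image_embeds"].shape)

    # 3) Paired text+image forward with the CLIP-style contrastive loss
    out = model(
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
        pixel_values=pixel_values,
        return_loss=True,
    )
    print(out["loss"])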