"""v0.1版本"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union

from transformers import AutoModel, CLIPVisionModel, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import ModelOutput

class MiniEmbedVisionConfig(PretrainedConfig):
    """Configuration for MiniEmbedVision: pairs a BGE text encoder with a CLIP vision encoder, with optional LoRA on the vision tower."""

    model_type = "miniembedvision"
def __init__(
self,
embed_dim: int = 768,
text_model_name: str = "BAAI/bge-base-zh-v1.5",
vision_model_name: str = "openai/clip-vit-base-patch32",
freeze_text: bool = True,
use_lora: bool = False,
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
**kwargs
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.text_model_name = text_model_name
self.vision_model_name = vision_model_name
self.freeze_text = freeze_text
self.use_lora = use_lora
self.lora_r = lora_r
self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout
try:
from peft import LoraConfig, get_peft_model, TaskType
_has_peft = True
except ImportError:
_has_peft = False
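
# Gathers a tensor from every DDP rank (used to build the [B, global_B] logit
# matrices in forward). Plain all_gather does not propagate gradients, so the
# gathered copies are detached from the autograd graph.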
def _concat_all_gather(tensor):
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
return tensor
tensors_gather = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(tensors_gather, tensor)
return torch.cat(tensors_gather, dim=0)
class MiniEmbedVisionModel(PreTrainedModel):
    """Dual-encoder embedding model: a BGE text tower and a CLIP vision tower projected into a shared, L2-normalised space."""

    config_class = MiniEmbedVisionConfig
def __init__(self, config):
super().__init__(config)
        # Text encoder
self.text_encoder = AutoModel.from_pretrained(config.text_model_name)
text_hidden = self.text_encoder.config.hidden_size
        # Optionally freeze the text encoder (freeze_text)
if config.freeze_text:
for p in self.text_encoder.parameters():
p.requires_grad = False
self.text_encoder.eval()
        # Vision encoder
self.vision_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
vis_hidden = self.vision_encoder.config.hidden_size
        # Optional LoRA adapters on the vision encoder
if config.use_lora:
if not _has_peft:
raise ImportError("peft is required for LoRA. Please install: pip install peft")
lora_config = LoraConfig(
r=config.lora_r,
lora_alpha=config.lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2", "proj"],
lora_dropout=config.lora_dropout,
bias="none",
task_type=TaskType.OTHER,
)
self.vision_encoder = get_peft_model(self.vision_encoder, lora_config)
self.text_proj = nn.Linear(text_hidden, config.embed_dim) if text_hidden != config.embed_dim else nn.Identity()
self.vision_proj = nn.Linear(vis_hidden, config.embed_dim) if vis_hidden != config.embed_dim else nn.Identity()
        # Learnable temperature, initialised to ln(1/0.07) as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))
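    # Masked mean pooling: average token embeddings over non-padding positions only.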
def _bge_pool(self, last_hidden, attention_mask):
mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
sum_emb = torch.sum(last_hidden * mask, dim=1)
sum_mask = mask.sum(dim=1)
sum_mask = torch.clamp(sum_mask, min=1e-9)
return sum_emb / sum_mask
def encode_text(self, input_ids, attention_mask):
        # Skip autograd for the text tower when it is fully frozen.
        grad_ctx = torch.no_grad() if not any(p.requires_grad for p in self.text_encoder.parameters()) else torch.enable_grad()
        with grad_ctx:
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
pooled = self._bge_pool(outputs.last_hidden_state, attention_mask)
emb = self.text_proj(pooled)
return F.normalize(emb, p=2, dim=-1)
def encode_image(self, pixel_values):
outputs = self.vision_encoder(pixel_values=pixel_values, return_dict=True)
        cls_feat = outputs.last_hidden_state[:, 0]  # CLS token embedding
emb = self.vision_proj(cls_feat)
return F.normalize(emb, p=2, dim=-1)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
return_loss: bool = False,
gather_for_ddp: bool = True,
labels: Optional[torch.Tensor] = None, # reserved for supervised contrastive
) -> Union[ModelOutput, Tuple[torch.Tensor, ...]]:
"""
模拟 Nomic 风格:既支持文本也支持图像
Supports:
- Inference: only text OR only image → returns embeddings.
- Training: both text AND image → returns logits and (optionally) loss.
"""
text_emb = None
image_emb = None
# Encode modalities if provided
if input_ids is not None and attention_mask is not None:
text_emb = self.encode_text(input_ids, attention_mask)
if pixel_values is not None:
image_emb = self.encode_image(pixel_values)
# Inference mode: single modality
if text_emb is not None and image_emb is None:
return ModelOutput(last_hidden_state=text_emb, text_embeds=text_emb)
if image_emb is not None and text_emb is None:
return ModelOutput(last_hidden_state=image_emb, image_embeds=image_emb)
# Training mode: both modalities present
if text_emb is None or image_emb is None:
raise ValueError("For training, both text and image inputs are required.")
        # Gather embeddings across GPUs so the returned logits span the global batch
        if gather_for_ddp and torch.distributed.is_available() and torch.distributed.is_initialized():
text_emb_all = _concat_all_gather(text_emb)
image_emb_all = _concat_all_gather(image_emb)
else:
text_emb_all = text_emb
image_emb_all = image_emb
logit_scale = torch.clamp(self.logit_scale.exp(), max=100.0)
logits_per_text = logit_scale * text_emb @ image_emb_all.t() # [B, global_B]
logits_per_image = logits_per_text.t() # [global_B, B]
        loss = None
        if return_loss:
            # CLIP-style symmetric contrastive loss on the local in-batch pairs:
            # text i matches image i. _concat_all_gather is non-differentiable, so
            # the gathered logits above are returned for inspection while the loss
            # uses only the local [B, B] block (a common simplification in practice).
            local_batch_size = text_emb.size(0)
            logits_local = logit_scale * text_emb @ image_emb.t()  # [B, B]
            labels_local = torch.arange(local_batch_size, device=text_emb.device)
            loss_i = F.cross_entropy(logits_local, labels_local)
            loss_t = F.cross_entropy(logits_local.t(), labels_local)
            loss = (loss_i + loss_t) / 2
return ModelOutput(
loss=loss,
logits_per_text=logits_per_text,
logits_per_image=logits_per_image,
text_embeds=text_emb,
image_embeds=image_emb,
)
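
# Make the custom classes discoverable via AutoConfig / AutoModel: saved
# checkpoints get an auto_map entry and can be loaded with trust_remote_code=True.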
MiniEmbedVisionConfig.register_for_auto_class()
MiniEmbedVisionModel.register_for_auto_class("AutoModel")
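

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model definition). It builds
# the model from the default config, which downloads the BGE and CLIP
# checkpoints named above; a checkpoint pushed to the Hub could instead be
# loaded with AutoModel.from_pretrained(repo_id, trust_remote_code=True).
# Requires Pillow for the dummy image.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer, CLIPImageProcessor

    config = MiniEmbedVisionConfig()
    model = MiniEmbedVisionModel(config).eval()

    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name)
    image_processor = CLIPImageProcessor.from_pretrained(config.vision_model_name)

    # Text-only inference -> L2-normalised text embeddings.
    text_inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
    with torch.no_grad():
        text_out = model(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
        )
    print(text_out["text_embeds"].shape)  # [1, embed_dim], i.e. [1, 768] by default

    # Image-only inference -> L2-normalised image embeddings.
    image = Image.new("RGB", (224, 224), color="white")
    pixel_values = image_processor(images=image, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        image_out = model(pixel_values=pixel_values)
    print(image_out["image_embeds"].shape)  # [1, embed_dim]

    # Paired text + image with return_loss=True -> CLIP-style contrastive loss.
    with torch.no_grad():
        pair_out = model(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
            pixel_values=pixel_values,
            return_loss=True,
        )
    print(pair_out["loss"], pair_out["logits_per_text"].shape)  # scalar loss, [1, 1]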