File size: 8,393 Bytes
0ddafe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""v0.1版本"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, CLIPVisionModel, PreTrainedModel, PretrainedConfig

class MiniEmbedVisionConfig(PretrainedConfig):

    model_type = "miniembedvision"

    def __init__(
        self,
        embed_dim: int =  768,
        text_model_name: str = "BAAI/bge-base-zh-v1.5",
        vision_model_name: str = "openai/clip-vit-base-patch32",
        freeze_text: bool = True,
        use_lora: bool = False,
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.text_model_name = text_model_name
        self.vision_model_name = vision_model_name
        self.freeze_text = freeze_text
        self.use_lora = use_lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, CLIPVisionModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from typing import Optional, Tuple, Union

try:
    from peft import LoraConfig, get_peft_model, TaskType
    _has_peft = True
except ImportError:
    _has_peft = False

def _concat_all_gather(tensor):
    if not torch.distributed.is_initialized():
        return tensor
    tensors_gather = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor)
    return torch.cat(tensors_gather, dim=0)

class MiniEmbedVisionModel(PreTrainedModel):
    config_class = MiniEmbedVisionConfig

    def __init__(self, config):
        super().__init__(config)
        
        # 文本编码器
        self.text_encoder = AutoModel.from_pretrained(config.text_model_name)
        text_hidden = self.text_encoder.config.hidden_size
        # 若冻结文本编码器 freeze_text
        if config.freeze_text:
            for p in self.text_encoder.parameters():
                p.requires_grad = False
            self.text_encoder.eval()
        
        # 视觉编码器
        self.vision_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
        vis_hidden = self.vision_encoder.config.hidden_size

        # lora config
        if config.use_lora:
            if not _has_peft:
                raise ImportError("peft is required for LoRA. Please install: pip install peft")
            lora_config = LoraConfig(
                r=config.lora_r,
                lora_alpha=config.lora_alpha,
                target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2", "proj"],
                lora_dropout=config.lora_dropout,
                bias="none",
                task_type=TaskType.OTHER,
            )
            self.vision_encoder = get_peft_model(self.vision_encoder, lora_config)

        self.text_proj = nn.Linear(text_hidden, config.embed_dim) if text_hidden != config.embed_dim else nn.Identity()
        self.vision_proj = nn.Linear(vis_hidden, config.embed_dim) if vis_hidden != config.embed_dim else nn.Identity()
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))
    
    def _bge_pool(self, last_hidden, attention_mask):
        mask = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
        sum_emb = torch.sum(last_hidden * mask, dim=1)
        sum_mask = mask.sum(dim=1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        return sum_emb / sum_mask

    def encode_text(self, input_ids, attention_mask):
        with torch.no_grad() if not any(p.requires_grad for p in self.text_encoder.parameters()) else torch.enable_grad():
            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            pooled = self._bge_pool(outputs.last_hidden_state, attention_mask)
        emb = self.text_proj(pooled)
        return F.normalize(emb, p=2, dim=-1)

    def encode_image(self, pixel_values):
        outputs = self.vision_encoder(pixel_values=pixel_values, return_dict=True)
        cls_feat = outputs.last_hidden_state[:, 0]
        emb = self.vision_proj(cls_feat)
        return F.normalize(emb, p=2, dim=-1)
    
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        return_loss: bool = False,
        gather_for_ddp: bool = True,
        labels: Optional[torch.Tensor] = None,  # reserved for supervised contrastive
    ) -> Union[ModelOutput, Tuple[torch.Tensor, ...]]:
        """
        模拟 Nomic 风格:既支持文本也支持图像
        Supports:
          - Inference: only text OR only image → returns embeddings.
          - Training: both text AND image → returns logits and (optionally) loss.
        """
        text_emb = None
        image_emb = None

        # Encode modalities if provided
        if input_ids is not None and attention_mask is not None:
            text_emb = self.encode_text(input_ids, attention_mask)
        if pixel_values is not None:
            image_emb = self.encode_image(pixel_values)

        # Inference mode: single modality
        if text_emb is not None and image_emb is None:
            return ModelOutput(last_hidden_state=text_emb, text_embeds=text_emb)
        if image_emb is not None and text_emb is None:
            return ModelOutput(last_hidden_state=image_emb, image_embeds=image_emb)

        # Training mode: both modalities present
        if text_emb is None or image_emb is None:
            raise ValueError("For training, both text and image inputs are required.")

        # Gather across GPUs for large-batch negatives
        if gather_for_ddp and torch.distributed.is_initialized():
            text_emb_all = _concat_all_gather(text_emb)
            image_emb_all = _concat_all_gather(image_emb)
        else:
            text_emb_all = text_emb
            image_emb_all = image_emb

        logit_scale = torch.clamp(self.logit_scale.exp(), max=100.0)
        logits_per_text = logit_scale * text_emb @ image_emb_all.t()   # [B, global_B]
        logits_per_image = logits_per_text.t()                         # [global_B, B]

        loss = None
        if return_loss:
            # Assume 1:1 pairing in local batch
            local_batch_size = text_emb.size(0)
            global_batch_size = text_emb_all.size(0)
            # Create labels: local i matches global i
            labels = torch.arange(local_batch_size, device=text_emb.device)
            # But logits_per_text is [local_B, global_B], so we need to align
            # Standard CLIP: each local text matches its corresponding image in global list
            # Find the global indices of local images
            if torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
                local_image_start = rank * local_batch_size
                image_indices = torch.arange(local_image_start, local_image_start + local_batch_size, device=text_emb.device)
                # We can't easily compute loss without knowing global alignment
                # Simpler: only use local batch for loss (common in practice)
                logits_local = logit_scale * text_emb @ image_emb.t()  # [B, B]
                labels_local = torch.arange(local_batch_size, device=text_emb.device)
                loss_i = F.cross_entropy(logits_local, labels_local)
                loss_t = F.cross_entropy(logits_local.t(), labels_local)
                loss = (loss_i + loss_t) / 2
            else:
                logits_local = logits_per_text  # [B, B]
                labels_local = torch.arange(local_batch_size, device=text_emb.device)
                loss_i = F.cross_entropy(logits_local, labels_local)
                loss_t = F.cross_entropy(logits_local.t(), labels_local)
                loss = (loss_i + loss_t) / 2

        return ModelOutput(
            loss=loss,
            logits_per_text=logits_per_text,
            logits_per_image=logits_per_image,
            text_embeds=text_emb,
            image_embeds=image_emb,
        )

MiniEmbedVisionConfig.register_for_auto_class()
MiniEmbedVisionModel.register_for_auto_class("AutoModel")