File size: 8,654 Bytes
8d017ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""

TMOS_Classifier: Binary classification head on top of LLaVA's transformer backbone.



Strips the autoregressive lm_head and replaces it with a single nn.Linear(hidden_size, 1)

for binary deepfake detection (0 = Real, 1 = Fake).



Usage:

    from tmos_classifier import TMOSClassifier, TMOS_LORA_CONFIG

    

    classifier = TMOSClassifier(base_model_id="llava-hf/llava-1.5-7b-hf")

    classifier = get_peft_model(classifier, TMOS_LORA_CONFIG)

    

    logit = classifier(input_ids=..., pixel_values=..., attention_mask=...)

    loss  = nn.BCEWithLogitsLoss()(logit, label)

"""

import torch
import torch.nn as nn
from transformers import LlavaForConditionalGeneration
from peft import LoraConfig


# ─── LoRA Configuration ──────────────────────────────────────────────
# r=64 LoRA adapters on every attention and MLP projection of the LLaMA
# language backbone, and nothing else.
#
# The CLIP vision tower also contains modules named q_proj / k_proj / v_proj
# (plus out_proj, fc1, fc2). A *list* of names is suffix-matched by PEFT, so
# it would silently attach adapters to the vision encoder as well. A *string*
# target_modules is treated as a regex that must fully match the module name,
# which lets us restrict adapters to modules under "language_model".
# Not targeted: lm_head (never used by TMOSClassifier), the CLIP vision tower,
# and the multi-modal projector (linear_1 / linear_2).
_LLM_LORA_TARGET_REGEX = (
    r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)"
)

TMOS_LORA_CONFIG = LoraConfig(
    r=64,
    lora_alpha=128,                # 2x rank as a common heuristic
    target_modules=_LLM_LORA_TARGET_REGEX,
    lora_dropout=0.1,
    bias="none",
    task_type=None,                # Custom classifier — not a causal LM
    modules_to_save=["classifier"],  # Always train (and save) the classification head
)


class TMOSClassifier(nn.Module):
    """
    Binary classifier built on the LLaVA transformer backbone.

    Architecture:

        pixel_values ──► CLIP Vision Tower ──► Multi-Modal Projector ──┐
                                                                       ├──► LLaMA Transformer ──► hidden state at last real token ──► classifier ──► logit
        input_ids ──► Token Embedding ─────────────────────────────────┘

    The lm_head is loaded as part of LlavaForConditionalGeneration, but it is
    frozen and never called. The hidden state of each sequence's last
    non-padding token is fed to a learned fp32 nn.Linear(hidden_size, 1) head.
    """

    def __init__(self, base_model_id, torch_dtype=torch.float16, device_map="auto", token=None):
        super().__init__()

        # Load the full LLaVA model (vision tower + projector + LLM; lm_head unused).
        self.base = LlavaForConditionalGeneration.from_pretrained(
            base_model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            device_map=device_map,
            token=token,
        )

        hidden_size = self.base.config.text_config.hidden_size  # 4096 for 7B

        # Freeze the lm_head — we never use it, and freezing prevents wasted
        # gradient computation if PEFT accidentally wraps it.
        for param in self.base.lm_head.parameters():
            param.requires_grad = False

        # Keep the classifier head in fp32 for numerical stability.
        self.classifier = nn.Linear(hidden_size, 1, dtype=torch.float32)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    @staticmethod
    def _last_token_index(attention_mask, seq_len):
        """
        Return, per row, the index of the last position whose mask entry is non-zero.

        Works for both right- and left-padded batches (unlike ``mask.sum() - 1``,
        which is only correct for right padding). An all-padding row maps to 0.
        The result is clamped to ``seq_len - 1`` in case the mask is longer than
        the hidden-state sequence.
        """
        positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
        last_idx = (attention_mask.ne(0).long() * positions).amax(dim=1)
        return last_idx.clamp(min=0, max=seq_len - 1)

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        labels=None,            # float targets, shape (B,) or (B, 1) — 0.0=real, 1.0=fake
        **kwargs,               # absorb extra keys from the data collator (ignored)
    ):
        """
        Single forward pass → logit + optional BCE loss.

        Args:
            input_ids, pixel_values, attention_mask: passed unchanged to
                ``self.base.model``. ``attention_mask`` is also used to locate
                the last non-padding token (either padding side is supported).
            labels: optional targets. Any tensor with B elements ((B,), (B, 1),
                or a 0-dim scalar when B == 1) is reshaped to the logit shape.
            **kwargs: accepted and ignored.

        Returns:
            dict with keys:
                "logit":  (B, 1) fp32 raw logit
                "loss":   scalar BCE-with-logits loss (only if labels provided)
        """
        # ── 1. Forward through the LLaVA backbone ──
        # Call the inner model (vision + projector + LLM) directly to get hidden
        # states; the language-model head is never evaluated.
        outputs = self.base.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state   # (B, seq_len, hidden_size)
        batch_size, seq_len = last_hidden_state.shape[:2]
        hidden_device = last_hidden_state.device

        # ── 2. Pool: hidden state of each sequence's last real token ──
        if attention_mask is not None:
            # The mask may live on a different device than the final shard's
            # output, so align the index before using it.
            last_idx = self._last_token_index(attention_mask, seq_len).to(hidden_device)
            pooled = last_hidden_state[torch.arange(batch_size, device=hidden_device), last_idx]
        else:
            # No mask → treat every position as real and take the final one.
            pooled = last_hidden_state[:, -1, :]

        # Replace non-finite activations defensively before the classifier.
        pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1e4, neginf=-1e4)

        # The head is created outside from_pretrained's device_map, so it may sit
        # on a different device than the last backbone shard; move it in place.
        head_device = next(self.classifier.parameters()).device
        if head_device != pooled.device:
            self.classifier.to(pooled.device)

        # ── 3. Classify (in fp32) ──
        logit = self.classifier(pooled.float())    # (B, 1)
        logit = torch.nan_to_num(logit, nan=0.0, posinf=20.0, neginf=-20.0)

        result = {"logit": logit}

        # ── 4. Loss ──
        if labels is not None:
            labels = labels.to(device=logit.device, dtype=logit.dtype).reshape(logit.shape)
            loss_fn = nn.BCEWithLogitsLoss()
            result["loss"] = loss_fn(logit, labels)

        return result

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Stub required by PEFT — we never generate text."""
        raise NotImplementedError("TMOSClassifier does not support generation.")

    def gradient_checkpointing_enable(self, **kwargs):
        """Delegate to the inner LLaVA model for HF Trainer compatibility."""
        self.base.model.gradient_checkpointing_enable(**kwargs)

    @property
    def config(self):
        """Expose the base model config for PEFT."""
        return self.base.config

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype


# ─── Standalone Test ──────────────────────────────────────────────────
def _smoke_test():
    """
    Run llava-1.5-7b on one dummy grey image and print the results.

    Prints parameter counts, the logit, the BCE loss against label 1.0 (fake)
    and the sigmoid probability. The Hugging Face token is read from the
    HF_TOKEN environment variable after loading a .env file via python-dotenv.
    """
    import os

    from dotenv import load_dotenv
    from PIL import Image
    from transformers import AutoProcessor

    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")
    model_id = "llava-hf/llava-1.5-7b-hf"

    print("Testing TMOSClassifier...")
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # fp16 is intended for GPU; many fp16 ops are unsupported or very slow on CPU.
    model_dtype = torch.float16 if use_cuda else torch.float32

    # Place the whole backbone on `device` at load time. Calling .to(device)
    # after device_map="auto" would move weights out from under accelerate's
    # dispatch hooks, which breaks multi-GPU or offloaded loads.
    clf = TMOSClassifier(
        base_model_id=model_id,
        torch_dtype=model_dtype,
        device_map={"": device},
        token=hf_token,
    )
    # The classifier head is built outside from_pretrained, so move it explicitly.
    clf.classifier.to(device)
    clf.eval()

    # Print parameter counts
    total = sum(p.numel() for p in clf.parameters())
    trainable = sum(p.numel() for p in clf.parameters() if p.requires_grad)
    print(f"Total params:     {total:>12,}")
    print(f"Trainable params: {trainable:>12,}")
    print(f"Classifier head:  {sum(p.numel() for p in clf.classifier.parameters()):,}")

    # Smoke test with dummy input
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    processor.patch_size = 14
    processor.vision_feature_select_strategy = "default"

    dummy_img = Image.new("RGB", (336, 336), color=(128, 128, 128))
    inputs = processor(
        text="USER: <image>\nIs this real?\nASSISTANT:",
        images=dummy_img,
        return_tensors="pt",
    ).to(device)

    labels = torch.tensor([1.0], device=device)  # fake

    with torch.no_grad():
        out = clf(**inputs, labels=labels)

    print(f"Logit:  {out['logit'].item():.4f}")
    print(f"Loss:   {out['loss'].item():.4f}")
    print(f"Prob:   {torch.sigmoid(out['logit']).item():.4f}")
    print("Test passed.")


if __name__ == "__main__":
    _smoke_test()