"""
TMOS_Classifier: Binary classification head on top of LLaVA's transformer backbone.

Strips the autoregressive lm_head and replaces it with a single
nn.Linear(hidden_size, 1) for binary deepfake detection (0 = Real, 1 = Fake).

Usage:
    from peft import get_peft_model
    from tmos_classifier import TMOSClassifier, TMOS_LORA_CONFIG

    classifier = TMOSClassifier(base_model_id="llava-hf/llava-1.5-7b-hf")
    classifier = get_peft_model(classifier, TMOS_LORA_CONFIG)
    out = classifier(input_ids=..., pixel_values=..., attention_mask=...)
    loss = nn.BCEWithLogitsLoss()(out["logit"], label)  # or pass labels=... and read out["loss"]
"""

import torch
import torch.nn as nn
from transformers import LlavaForConditionalGeneration
from peft import LoraConfig


# ─── LoRA Configuration ──────────────────────────────────────────────
# Massive expansion: r=64 across ALL linear layers in the LLM backbone
# (attention q/k/v/o and MLP gate/up/down projections).
#
# The CLIP vision tower's attention layers are ALSO named q_proj / k_proj /
# v_proj, so a plain list of suffixes would silently attach LoRA to the
# vision encoder as well. Passing a regex *string* makes PEFT perform a full
# regex match against each module's dotted path, which lets us restrict LoRA
# to modules living under "language_model". As a result lm_head (we discard
# it), the CLIP fc1/fc2/out_proj/q/k/v layers, and the multi-modal projector's
# linear_1/linear_2 get no LoRA, keeping the vision encoder and projector
# frozen while only the language transformer is adapted.
TMOS_LORA_CONFIG = LoraConfig(
    r=64,
    lora_alpha=128,  # 2x rank as a common heuristic
    target_modules=(
        r".*language_model.*\."
        r"(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)"
    ),
    lora_dropout=0.1,
    bias="none",
    task_type=None,  # Custom classifier — not a causal LM
    modules_to_save=["classifier"],  # Always train the classification head
)


class TMOSClassifier(nn.Module):
    """
    Binary classifier built on the LLaVA transformer backbone.

    Architecture:
        pixel_values ──► CLIP Vision Tower ──► Multi-Modal Projector ──┐
                                                                       ├──► LLaMA Transformer
        input_ids    ──► Token Embedding ──────────────────────────────┘
            ──► hidden state of the last real (non-padded) token
            ──► classifier ──► logit

    The lm_head is never used. We extract the final non-padded token's hidden
    state and pass it through a learned nn.Linear(hidden_size, 1) head.
    """

    def __init__(self, base_model_id, torch_dtype=torch.float16, device_map="auto", token=None):
        """
        Load the LLaVA backbone and attach a fresh binary classification head.

        Args:
            base_model_id: Hugging Face model id or local path of a LLaVA checkpoint.
            torch_dtype: dtype used to load the backbone weights.
            device_map: forwarded to ``from_pretrained`` for device placement.
            token: optional Hugging Face access token.
        """
        super().__init__()

        # Load the full LLaVA model (we need vision tower + projector + LLM).
        self.base = LlavaForConditionalGeneration.from_pretrained(
            base_model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            device_map=device_map,
            token=token,
        )

        hidden_size = self.base.config.text_config.hidden_size  # 4096 for 7B

        # Freeze the lm_head — we won't use it, but freezing prevents
        # wasted gradient computation if PEFT accidentally wraps it.
        for param in self.base.lm_head.parameters():
            param.requires_grad = False

        # Keep the classifier head in fp32 for numerical stability. It is
        # created on the default device and moved lazily in forward() to
        # wherever the backbone's final hidden states end up.
        self.classifier = nn.Linear(hidden_size, 1, dtype=torch.float32)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        labels=None,  # float tensor of shape (B,) — 0.0=real, 1.0=fake
        **kwargs,  # absorb extra keys from data collator
    ):
        """
        Single forward pass → logit + optional BCE loss.

        Args:
            input_ids: token ids, forwarded to the LLaVA backbone.
            pixel_values: image tensor, forwarded to the LLaVA backbone.
            attention_mask: (B, seq_len) mask where non-zero marks real tokens.
                Both left- and right-padded batches are supported.
            labels: optional targets of shape (B,) or (B, 1); cast to float.
            **kwargs: ignored; lets collators pass extra keys.

        Returns:
            dict with keys:
                "logit": (B, 1) raw fp32 logit
                "loss":  scalar BCE loss (only if labels provided)
        """
        # ── 1. Forward through the LLaVA backbone ──
        # We call the internal model (vision + projector + LLM) directly,
        # asking for hidden states, NOT for language-model logits.
        outputs = self.base.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # last_hidden_state: (B, seq_len, hidden_size)
        last_hidden_state = outputs.last_hidden_state

        # ── 2. Pool: extract the final non-padded token per sequence ──
        pooled = self._pool_last_token(last_hidden_state, attention_mask)

        # Replace non-finite activations defensively before the classifier.
        pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1e4, neginf=-1e4)

        # Match classifier device to pooled features when model is sharded/offloaded.
        # Read the device from the parameters rather than `.weight` so this still
        # works after PEFT wraps the head (modules_to_save). Module.to() moves the
        # parameters in place, so no reassignment is needed.
        classifier_device = next(self.classifier.parameters()).device
        if classifier_device != pooled.device:
            self.classifier.to(pooled.device)

        # ── 3. Classify ──
        logit = self.classifier(pooled.float())  # (B, 1)
        logit = torch.nan_to_num(logit, nan=0.0, posinf=20.0, neginf=-20.0)

        result = {"logit": logit}

        # ── 4. Loss ──
        if labels is not None:
            labels = labels.to(device=logit.device, dtype=logit.dtype)
            if labels.dim() == 1:
                labels = labels.unsqueeze(1)  # (B,) → (B, 1)
            loss_fn = nn.BCEWithLogitsLoss()
            result["loss"] = loss_fn(logit, labels)

        return result

    @staticmethod
    def _pool_last_token(last_hidden_state, attention_mask):
        """Return the hidden state of each sequence's last real token, shape (B, hidden)."""
        if attention_mask is None:
            # No mask → just take the last position.
            return last_hidden_state[:, -1, :]

        batch_size, seq_len = last_hidden_state.shape[:2]
        device = last_hidden_state.device

        # Build indices on the hidden-state device: under a sharded device_map
        # the mask may live on a different device than the final layer output.
        is_real = attention_mask.to(device) != 0
        positions = torch.arange(is_real.size(1), device=device)

        # Highest position that holds a real token. Unlike `mask.sum() - 1`,
        # this is correct for both right padding ([1,1,1,0,0]) and left
        # padding ([0,0,1,1,1]). Rows with no real tokens fall back to index 0.
        last_idx = (positions * is_real).amax(dim=1)
        # Clamp in case the mask is longer than the hidden-state sequence.
        last_idx = last_idx.clamp(max=seq_len - 1)

        return last_hidden_state[torch.arange(batch_size, device=device), last_idx]

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Stub required by PEFT — we never generate text."""
        raise NotImplementedError("TMOSClassifier does not support generation.")

    def gradient_checkpointing_enable(self, **kwargs):
        """Delegate to the base model for HF Trainer compatibility."""
        self.base.model.gradient_checkpointing_enable(**kwargs)

    @property
    def config(self):
        """Expose the base model config for PEFT."""
        return self.base.config

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """dtype of the first registered parameter."""
        return next(self.parameters()).dtype


# ─── Standalone Test ──────────────────────────────────────────────────
if __name__ == "__main__":
    import os

    from dotenv import load_dotenv
    from PIL import Image
    from transformers import AutoProcessor

    load_dotenv()
    HF_TOKEN = os.getenv("HF_TOKEN")

    print("Testing TMOSClassifier...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    clf = TMOSClassifier(
        base_model_id="llava-hf/llava-1.5-7b-hf",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )
    clf.to(device)
    clf.eval()  # smoke test: no dropout

    # Print parameter counts
    total = sum(p.numel() for p in clf.parameters())
    trainable = sum(p.numel() for p in clf.parameters() if p.requires_grad)
    print(f"Total params:     {total:>12,}")
    print(f"Trainable params: {trainable:>12,}")
    print(f"Classifier head:  {sum(p.numel() for p in clf.classifier.parameters()):,}")

    # Smoke test with dummy input
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", token=HF_TOKEN)
    processor.patch_size = 14
    processor.vision_feature_select_strategy = "default"

    dummy_img = Image.new("RGB", (336, 336), color=(128, 128, 128))
    # The prompt must contain the <image> placeholder; otherwise the backbone
    # finds image features but no image tokens to merge them into.
    inputs = processor(
        text="USER: <image>\nIs this real?\nASSISTANT:",
        images=dummy_img,
        return_tensors="pt",
    ).to(device)

    labels = torch.tensor([1.0], device=device)  # fake

    with torch.no_grad():
        out = clf(**inputs, labels=labels)

    print(f"Logit: {out['logit'].item():.4f}")
    print(f"Loss:  {out['loss'].item():.4f}")
    print(f"Prob:  {torch.sigmoid(out['logit']).item():.4f}")
    print("Test passed.")