""" Preprocessing Pipeline for Multimodal Deepfake Detection ========================================================= Handles: - Image preprocessing (resize, normalize, augment) - Video frame extraction and preprocessing - Text tokenization and preprocessing - Dataset loading and formatting """ import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image import numpy as np import io # ============================================================ # Image/Video Preprocessing # ============================================================ def get_image_transforms(mode='train', image_size=224): """Get image transforms for training or evaluation. Based on DeepfakeBench preprocessing pipeline: - Resize to target size - Data augmentation for training (flip, color jitter, blur) - Normalize with ImageNet stats """ if mode == 'train': return transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(degrees=10), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05), transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)), transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.RandomErasing(p=0.1), ]) else: return transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def preprocess_image(image, transform): """Preprocess a single image (PIL or tensor). Args: image: PIL Image or bytes transform: torchvision transform pipeline Returns: tensor: (C, H, W) preprocessed image tensor """ if isinstance(image, bytes): image = Image.open(io.BytesIO(image)) if isinstance(image, dict) and 'bytes' in image: image = Image.open(io.BytesIO(image['bytes'])) if not isinstance(image, Image.Image): raise ValueError(f"Expected PIL Image, got {type(image)}") image = image.convert('RGB') return transform(image) def extract_video_frames(video_path, num_frames=32, uniform=True): """Extract frames from video for deepfake detection. Based on DeepfakeBench: sample 32 frames uniformly. Args: video_path: Path to video file num_frames: Number of frames to extract uniform: Whether to sample uniformly or randomly Returns: frames: list of PIL Images """ try: import cv2 except ImportError: raise ImportError("OpenCV required for video processing: pip install opencv-python") cap = cv2.VideoCapture(video_path) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if total_frames <= 0: raise ValueError(f"Cannot read video: {video_path}") if uniform: indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) else: indices = sorted(np.random.choice(total_frames, min(num_frames, total_frames), replace=False)) frames = [] for idx in indices: cap.set(cv2.CAP_PROP_POS_FRAMES, idx) ret, frame = cap.read() if ret: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(Image.fromarray(frame_rgb)) cap.release() return frames # ============================================================ # Text Preprocessing # ============================================================ def get_tokenizer(model_name='roberta-base', max_length=512): """Get tokenizer for text branch.""" from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) return tokenizer, max_length def preprocess_text(text, tokenizer, max_length=512): """Tokenize text for the text branch. Args: text: input string tokenizer: HF tokenizer max_length: maximum sequence length Returns: dict with input_ids and attention_mask tensors """ encoding = tokenizer( text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt' ) return { 'input_ids': encoding['input_ids'].squeeze(0), 'attention_mask': encoding['attention_mask'].squeeze(0) } # ============================================================ # Dataset Classes # ============================================================ class ImageDeepfakeDataset(Dataset): """Dataset for image-based deepfake detection. Compatible with Hemg/deepfake-and-real-images format: - image: PIL Image - label: 0=Fake, 1=Real (we flip to 0=Real, 1=Fake for consistency) """ def __init__(self, hf_dataset, transform=None, label_column='label', image_column='image', flip_labels=True): self.dataset = hf_dataset self.transform = transform or get_image_transforms('train') self.label_column = label_column self.image_column = image_column self.flip_labels = flip_labels # Hemg dataset: Fake=0, Real=1 def __len__(self): return len(self.dataset) def __getitem__(self, idx): item = self.dataset[idx] image = item[self.image_column] if isinstance(image, dict) and 'bytes' in image: image = Image.open(io.BytesIO(image['bytes'])) elif isinstance(image, bytes): image = Image.open(io.BytesIO(image)) if isinstance(image, Image.Image): image = image.convert('RGB') else: raise ValueError(f"Unexpected image type: {type(image)}") image_tensor = self.transform(image) label = item[self.label_column] if self.flip_labels: # Convert from Fake=0,Real=1 to Real=0,Fake=1 label = 1 - label return { 'pixel_values': image_tensor, 'labels': torch.tensor(label, dtype=torch.long) } class TextDeepfakeDataset(Dataset): """Dataset for text-based AI-generated content detection. Compatible with artem9k/ai-text-detection-pile format: - text: string content - source: 'human' or 'ai' """ def __init__(self, hf_dataset, tokenizer, max_length=512, text_column='text', label_column='source'): self.dataset = hf_dataset self.tokenizer = tokenizer self.max_length = max_length self.text_column = text_column self.label_column = label_column self.label_map = {'human': 0, 'ai': 1} def __len__(self): return len(self.dataset) def __getitem__(self, idx): item = self.dataset[idx] text = item[self.text_column] # Truncate very long text before tokenization for efficiency if len(text) > 5000: text = text[:5000] encoding = self.tokenizer( text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt' ) label_str = item[self.label_column] label = self.label_map.get(label_str, 0) return { 'input_ids': encoding['input_ids'].squeeze(0), 'attention_mask': encoding['attention_mask'].squeeze(0), 'labels': torch.tensor(label, dtype=torch.long) } class MultimodalDataset(Dataset): """Combined dataset for multimodal training. Interleaves image and text samples, padding the missing modality. """ def __init__(self, image_dataset=None, text_dataset=None): self.image_dataset = image_dataset self.text_dataset = text_dataset self.image_len = len(image_dataset) if image_dataset else 0 self.text_len = len(text_dataset) if text_dataset else 0 self.total_len = self.image_len + self.text_len def __len__(self): return self.total_len def __getitem__(self, idx): if idx < self.image_len: item = self.image_dataset[idx] item['modality'] = 'visual' return item else: item = self.text_dataset[idx - self.image_len] item['modality'] = 'text' return item