| """ | |
| Preprocessing Pipeline for Multimodal Deepfake Detection | |
| ========================================================= | |
| Handles: | |
| - Image preprocessing (resize, normalize, augment) | |
| - Video frame extraction and preprocessing | |
| - Text tokenization and preprocessing | |
| - Dataset loading and formatting | |
| """ | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import io | |
| # ============================================================ | |
| # Image/Video Preprocessing | |
| # ============================================================ | |
| def get_image_transforms(mode='train', image_size=224): | |
| """Get image transforms for training or evaluation. | |
| Based on DeepfakeBench preprocessing pipeline: | |
| - Resize to target size | |
| - Data augmentation for training (flip, color jitter, blur) | |
| - Normalize with ImageNet stats | |
| """ | |
| if mode == 'train': | |
| return transforms.Compose([ | |
| transforms.Resize((image_size, image_size)), | |
| transforms.RandomHorizontalFlip(p=0.5), | |
| transforms.RandomRotation(degrees=10), | |
| transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05), | |
| transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)), | |
| transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
| std=[0.229, 0.224, 0.225]), | |
| transforms.RandomErasing(p=0.1), | |
| ]) | |
| else: | |
| return transforms.Compose([ | |
| transforms.Resize((image_size, image_size)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], | |
| std=[0.229, 0.224, 0.225]), | |
| ]) | |
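
# Usage sketch (hypothetical): build each pipeline once and reuse it. The
# train pipeline is stochastic, so repeated calls on the same image yield
# different augmented tensors; 'face.jpg' is a placeholder path.
#
#     train_tf = get_image_transforms('train', image_size=224)
#     eval_tf = get_image_transforms('eval', image_size=224)
#     x = train_tf(Image.open('face.jpg').convert('RGB'))  # (3, 224, 224)
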
def preprocess_image(image, transform):
    """Preprocess a single image.

    Args:
        image: PIL Image, raw bytes, or a dict with a 'bytes' key
            (as yielded by some HF datasets image columns)
        transform: torchvision transform pipeline
    Returns:
        tensor: (C, H, W) preprocessed image tensor
    """
    if isinstance(image, bytes):
        image = Image.open(io.BytesIO(image))
    if isinstance(image, dict) and 'bytes' in image:
        image = Image.open(io.BytesIO(image['bytes']))
    if not isinstance(image, Image.Image):
        raise ValueError(f"Expected PIL Image, got {type(image)}")
    image = image.convert('RGB')
    return transform(image)
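
# Sketch: decode and transform in one call; `raw_bytes` stands in for
# encoded image bytes from any source.
#
#     tensor = preprocess_image(raw_bytes, get_image_transforms('eval'))
#     tensor.shape  # torch.Size([3, 224, 224])
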
def extract_video_frames(video_path, num_frames=32, uniform=True):
    """Extract frames from video for deepfake detection.

    Based on DeepfakeBench: sample 32 frames uniformly.

    Args:
        video_path: Path to video file
        num_frames: Number of frames to extract
        uniform: Whether to sample uniformly or randomly
    Returns:
        frames: list of PIL Images
    """
    try:
        import cv2
    except ImportError:
        raise ImportError("OpenCV required for video processing: pip install opencv-python")
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        cap.release()
        raise ValueError(f"Cannot read video: {video_path}")
    if uniform:
        # Evenly spaced indices; duplicates occur if the video has fewer
        # than num_frames frames.
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    else:
        indices = sorted(np.random.choice(total_frames, min(num_frames, total_frames), replace=False))
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB for PIL/torchvision.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
    cap.release()
    if not frames:
        raise ValueError(f"No frames could be decoded from: {video_path}")
    return frames
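
# Sketch: turn a clip into a (T, C, H, W) tensor for a video model;
# 'clip.mp4' is a placeholder path.
#
#     eval_tf = get_image_transforms('eval')
#     frames = extract_video_frames('clip.mp4', num_frames=32)
#     clip = torch.stack([eval_tf(f) for f in frames])  # (32, 3, 224, 224)
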
# ============================================================
# Text Preprocessing
# ============================================================

def get_tokenizer(model_name='roberta-base', max_length=512):
    """Get tokenizer for the text branch.

    Returns the tokenizer together with the max sequence length so both
    can be passed around as a unit.
    """
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer, max_length


def preprocess_text(text, tokenizer, max_length=512):
    """Tokenize text for the text branch.

    Args:
        text: input string
        tokenizer: HF tokenizer
        max_length: maximum sequence length
    Returns:
        dict with input_ids and attention_mask tensors
    """
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    return {
        'input_ids': encoding['input_ids'].squeeze(0),
        'attention_mask': encoding['attention_mask'].squeeze(0)
    }
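
# Sketch: one text sample ready for a RoBERTa-style encoder; with
# padding='max_length' every sample comes back at the full length.
#
#     tok, max_len = get_tokenizer('roberta-base')
#     batch = preprocess_text("Sample text", tok, max_length=max_len)
#     batch['input_ids'].shape  # torch.Size([512])
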
# ============================================================
# Dataset Classes
# ============================================================

class ImageDeepfakeDataset(Dataset):
    """Dataset for image-based deepfake detection.

    Compatible with the Hemg/deepfake-and-real-images format:
    - image: PIL Image
    - label: 0=Fake, 1=Real (we flip to 0=Real, 1=Fake for consistency)
    """

    def __init__(self, hf_dataset, transform=None, label_column='label',
                 image_column='image', flip_labels=True):
        self.dataset = hf_dataset
        self.transform = transform or get_image_transforms('train')
        self.label_column = label_column
        self.image_column = image_column
        self.flip_labels = flip_labels  # Hemg dataset: Fake=0, Real=1

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item[self.image_column]
        if isinstance(image, dict) and 'bytes' in image:
            image = Image.open(io.BytesIO(image['bytes']))
        elif isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))
        if isinstance(image, Image.Image):
            image = image.convert('RGB')
        else:
            raise ValueError(f"Unexpected image type: {type(image)}")
        image_tensor = self.transform(image)
        label = item[self.label_column]
        if self.flip_labels:
            # Convert from Fake=0,Real=1 to Real=0,Fake=1
            label = 1 - label
        return {
            'pixel_values': image_tensor,
            'labels': torch.tensor(label, dtype=torch.long)
        }
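
# Sketch, assuming the HF dataset named in the docstring is available:
#
#     from datasets import load_dataset
#     raw = load_dataset('Hemg/deepfake-and-real-images', split='train')
#     ds = ImageDeepfakeDataset(raw, transform=get_image_transforms('train'))
#     sample = ds[0]  # {'pixel_values': (3, 224, 224) tensor, 'labels': 0/1}
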
class TextDeepfakeDataset(Dataset):
    """Dataset for text-based AI-generated content detection.

    Compatible with the artem9k/ai-text-detection-pile format:
    - text: string content
    - source: 'human' or 'ai'
    """

    def __init__(self, hf_dataset, tokenizer, max_length=512,
                 text_column='text', label_column='source'):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_column = text_column
        self.label_column = label_column
        self.label_map = {'human': 0, 'ai': 1}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        text = item[self.text_column]
        # Truncate very long text before tokenization for efficiency;
        # the tokenizer truncates to max_length tokens anyway.
        if len(text) > 5000:
            text = text[:5000]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        label_str = item[self.label_column]
        # Unknown label strings fall back to 0 (human).
        label = self.label_map.get(label_str, 0)
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }
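
# Sketch, assuming the HF dataset named in the docstring is available:
#
#     from datasets import load_dataset
#     tok, max_len = get_tokenizer('roberta-base')
#     raw = load_dataset('artem9k/ai-text-detection-pile', split='train')
#     txt_ds = TextDeepfakeDataset(raw, tok, max_length=max_len)
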
class MultimodalDataset(Dataset):
    """Combined dataset for multimodal training.

    Concatenates image and text samples (images first, then text) and tags
    each item with its modality. Items carry only their own modality's
    fields, so any padding of the missing modality must happen in the
    collate function; a shuffling sampler interleaves the modalities at
    training time.
    """

    def __init__(self, image_dataset=None, text_dataset=None):
        self.image_dataset = image_dataset
        self.text_dataset = text_dataset
        self.image_len = len(image_dataset) if image_dataset else 0
        self.text_len = len(text_dataset) if text_dataset else 0
        self.total_len = self.image_len + self.text_len

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        # Indices [0, image_len) map to the image dataset; the rest map to
        # the text dataset, offset by image_len.
        if idx < self.image_len:
            item = self.image_dataset[idx]
            item['modality'] = 'visual'
            return item
        else:
            item = self.text_dataset[idx - self.image_len]
            item['modality'] = 'text'
            return item
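
# Sketch: because image items carry 'pixel_values' while text items carry
# 'input_ids'/'attention_mask', PyTorch's default collate cannot batch mixed
# samples; a custom collate_fn that groups samples by their 'modality' tag
# is assumed here (`collate_by_modality` is hypothetical, not part of this
# module).
#
#     combined = MultimodalDataset(image_dataset=img_ds, text_dataset=txt_ds)
#     loader = DataLoader(combined, batch_size=32, shuffle=True,
#                         collate_fn=collate_by_modality)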