| """ |
| Preprocessing Pipeline for Multimodal Deepfake Detection |
| ========================================================= |
| Handles: |
| - Image preprocessing (resize, normalize, augment) |
| - Video frame extraction and preprocessing |
| - Text tokenization and preprocessing |
| - Dataset loading and formatting |
| """ |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from PIL import Image |
| import numpy as np |
| import io |
|
|
|
|
def get_image_transforms(mode='train', image_size=224):
    """Build the torchvision transform pipeline for the requested mode.

    Args:
        mode: ``'train'`` selects the augmenting pipeline; any other value
            selects the plain resize / to-tensor / normalize pipeline.
        image_size: Side length of the square every image is resized to.

    Returns:
        A ``transforms.Compose`` holding the selected steps.
    """
    target_size = (image_size, image_size)
    # Both pipelines normalize with the same per-channel constants.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if mode != 'train':
        return transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            normalize,
        ])

    # Augmentations that run on the PIL image, before tensor conversion.
    pil_steps = [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    ]
    # RandomErasing is placed after ToTensor/Normalize, i.e. it runs on the
    # already-normalized tensor.
    tensor_steps = [
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.1),
    ]
    return transforms.Compose(pil_steps + tensor_steps)
|
|
|
|
def preprocess_image(image, transform):
    """Decode an image if necessary, force RGB, and apply ``transform``.

    Args:
        image: A PIL image, raw encoded ``bytes``, or a dict that has a
            ``'bytes'`` key holding the encoded image.
        transform: Callable applied to the RGB PIL image.

    Returns:
        Whatever ``transform`` returns for the RGB image.

    Raises:
        ValueError: If ``image`` is none of the accepted input forms.
    """
    if isinstance(image, bytes):
        image = Image.open(io.BytesIO(image))
    elif isinstance(image, dict) and 'bytes' in image:
        image = Image.open(io.BytesIO(image['bytes']))

    # Anything that is still not a PIL image at this point is unsupported.
    if isinstance(image, Image.Image):
        return transform(image.convert('RGB'))
    raise ValueError(f"Expected PIL Image, got {type(image)}")
|
|
|
|
def extract_video_frames(video_path, num_frames=32, uniform=True):
    """Sample frames from a video and return them as RGB PIL images.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        num_frames: Number of frames to sample.
        uniform: When true, indices are spread evenly from the first to the
            last frame with ``np.linspace`` (indices repeat if ``num_frames``
            exceeds the frame count). Otherwise up to ``num_frames`` distinct
            indices are drawn at random and visited in ascending order.

    Returns:
        A list of PIL images in sampling order. Frames whose read fails are
        skipped, so the list may be shorter than ``num_frames``.

    Raises:
        ImportError: If OpenCV is not installed.
        ValueError: If the video reports no frames.
    """
    try:
        import cv2
    except ImportError as err:
        raise ImportError("OpenCV required: pip install opencv-python") from err

    cap = cv2.VideoCapture(video_path)
    # The capture handle must be released on every exit path, including the
    # "cannot read" error below and any failure while decoding frames.
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            raise ValueError(f"Cannot read video: {video_path}")

        if uniform:
            indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        else:
            sample_size = min(num_frames, total_frames)
            indices = sorted(np.random.choice(total_frames, sample_size, replace=False))

        frames = []
        for idx in indices:
            # Seek to the requested frame before reading it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if not ret:
                continue
            # OpenCV delivers BGR; PIL expects RGB channel order.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
        return frames
    finally:
        cap.release()
|
|
|
|
def get_tokenizer(model_name='roberta-base', max_length=512):
    """Load a tokenizer via ``AutoTokenizer.from_pretrained``.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        max_length: Returned unchanged alongside the tokenizer.

    Returns:
        A ``(tokenizer, max_length)`` tuple.
    """
    # Imported here so the module can be loaded without transformers.
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained(model_name), max_length
|
|
|
|
def preprocess_text(text, tokenizer, max_length=512):
    """Tokenize ``text`` and return the ids and mask without the batch axis.

    Args:
        text: Input passed as the first argument to ``tokenizer``.
        tokenizer: Callable invoked with ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` keyword arguments.
        max_length: Length used for padding and truncation.

    Returns:
        Dict with ``'input_ids'`` and ``'attention_mask'``, each taken from
        the tokenizer output with ``squeeze(0)`` applied.
    """
    tokenizer_options = {
        'max_length': max_length,
        'padding': 'max_length',
        'truncation': True,
        'return_tensors': 'pt',
    }
    encoded = tokenizer(text, **tokenizer_options)
    # squeeze(0) drops the leading size-1 dimension of each returned tensor.
    return {
        field: encoded[field].squeeze(0)
        for field in ('input_ids', 'attention_mask')
    }
|
|
|
|
class ImageDeepfakeDataset(Dataset):
    """Map-style dataset yielding transformed images with integer labels.

    Args:
        hf_dataset: Indexable collection of records; each record is indexed
            by ``image_column`` and ``label_column``.
        transform: Callable applied to the RGB PIL image. A falsy value
            selects ``get_image_transforms('train')``.
        label_column: Key of the label inside a record.
        image_column: Key of the image inside a record.
        flip_labels: When true, the stored label is replaced by
            ``1 - label`` before being returned.
    """

    def __init__(self, hf_dataset, transform=None, label_column='label', image_column='image', flip_labels=True):
        if not transform:
            transform = get_image_transforms('train')
        self.dataset = hf_dataset
        self.transform = transform
        self.label_column = label_column
        self.image_column = image_column
        self.flip_labels = flip_labels

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'pixel_values', 'labels'}`` for the record at ``idx``.

        Raises:
            ValueError: If the image is not a PIL image, raw bytes, or a
                dict with a ``'bytes'`` key.
        """
        record = self.dataset[idx]
        picture = record[self.image_column]

        # Decode encoded payloads into a PIL image first.
        if isinstance(picture, dict) and 'bytes' in picture:
            picture = Image.open(io.BytesIO(picture['bytes']))
        elif isinstance(picture, bytes):
            picture = Image.open(io.BytesIO(picture))

        if not isinstance(picture, Image.Image):
            raise ValueError(f"Unexpected image type: {type(picture)}")

        pixel_values = self.transform(picture.convert('RGB'))

        target = record[self.label_column]
        if self.flip_labels:
            target = 1 - target
        return {
            'pixel_values': pixel_values,
            'labels': torch.tensor(target, dtype=torch.long),
        }
|
|
|
|
class TextDeepfakeDataset(Dataset):
    """Map-style dataset yielding tokenized text with integer labels.

    Args:
        hf_dataset: Indexable collection of records; each record is indexed
            by ``text_column`` and ``label_column``.
        tokenizer: Callable invoked with ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` keyword arguments.
        max_length: Length used for padding and truncation.
        text_column: Key of the text inside a record.
        label_column: Key of the label string inside a record.
    """

    # Texts longer than this many characters are cut before tokenization.
    _MAX_CHARS = 5000

    def __init__(self, hf_dataset, tokenizer, max_length=512, text_column='text', label_column='source'):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_column = text_column
        self.label_column = label_column
        # Label strings outside this map fall back to 0 in __getitem__.
        self.label_map = {'human': 0, 'ai': 1}

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'attention_mask', 'labels'}`` for ``idx``."""
        record = self.dataset[idx]

        raw_text = record[self.text_column]
        text = raw_text if len(raw_text) <= self._MAX_CHARS else raw_text[:self._MAX_CHARS]

        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        target = self.label_map.get(record[self.label_column], 0)
        sample = {
            field: encoded[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample
|
|
|
|
class MultimodalDataset(Dataset):
    """Concatenation of an image dataset and a text dataset.

    Indices ``0 .. image_len - 1`` address the image dataset; the remaining
    indices address the text dataset. Every returned item carries a
    ``'modality'`` entry set to ``'visual'`` or ``'text'``.

    Args:
        image_dataset: Sized, indexable dataset of dict items, or ``None``.
        text_dataset: Sized, indexable dataset of dict items, or ``None``.
    """

    def __init__(self, image_dataset=None, text_dataset=None):
        self.image_dataset = image_dataset
        self.text_dataset = text_dataset
        # Explicit None checks avoid relying on the truthiness of the
        # dataset objects themselves.
        self.image_len = len(image_dataset) if image_dataset is not None else 0
        self.text_len = len(text_dataset) if text_dataset is not None else 0
        self.total_len = self.image_len + self.text_len

    def __len__(self):
        """Return the combined number of image and text items."""
        return self.total_len

    def __getitem__(self, idx):
        """Return a copy of the item at ``idx`` tagged with its modality.

        Negative indices count from the end of the combined dataset.

        Raises:
            IndexError: If ``idx`` lies outside the combined range.
        """
        if idx < 0:
            # Rebind rather than mutate in place, in case idx is a mutable
            # numeric object owned by the caller.
            idx = idx + self.total_len
        if not 0 <= idx < self.total_len:
            raise IndexError(
                f"index {idx} out of range for dataset of length {self.total_len}"
            )

        if idx < self.image_len:
            item = self.image_dataset[idx]
            modality = 'visual'
        else:
            item = self.text_dataset[idx - self.image_len]
            modality = 'text'

        # Build a new dict so the record held by the underlying dataset is
        # not modified.
        return {**item, 'modality': modality}
|
|