alianassmaaa committed on
Commit e08d744 · verified · 1 Parent(s): 24f0d7e

Add preprocessing pipeline

Files changed (1):
  1. preprocessing.py +161 -0
preprocessing.py ADDED
@@ -0,0 +1,161 @@
+ """
+ Preprocessing Pipeline for Multimodal Deepfake Detection
+ =========================================================
+ Handles:
+ - Image preprocessing (resize, normalize, augment)
+ - Video frame extraction and preprocessing
+ - Text tokenization and preprocessing
+ - Dataset loading and formatting
+ """
+
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from PIL import Image
+ import numpy as np
+ import io
+
+
+ def get_image_transforms(mode='train', image_size=224):
+     if mode == 'train':
+         return transforms.Compose([
+             transforms.Resize((image_size, image_size)),
+             transforms.RandomHorizontalFlip(p=0.5),
+             transforms.RandomRotation(degrees=10),
+             transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
+             transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
+             transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+             transforms.RandomErasing(p=0.1),
+         ])
+     else:
+         return transforms.Compose([
+             transforms.Resize((image_size, image_size)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ])
+
+
+ def preprocess_image(image, transform):
+     if isinstance(image, bytes):
+         image = Image.open(io.BytesIO(image))
+     if isinstance(image, dict) and 'bytes' in image:
+         image = Image.open(io.BytesIO(image['bytes']))
+     if not isinstance(image, Image.Image):
+         raise ValueError(f"Expected PIL Image, got {type(image)}")
+     image = image.convert('RGB')
+     return transform(image)
+
+
+ def extract_video_frames(video_path, num_frames=32, uniform=True):
+     try:
+         import cv2
+     except ImportError:
+         raise ImportError("OpenCV required: pip install opencv-python")
+     cap = cv2.VideoCapture(video_path)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     if total_frames <= 0:
+         raise ValueError(f"Cannot read video: {video_path}")
+     if uniform:
+         indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+     else:
+         indices = sorted(np.random.choice(total_frames, min(num_frames, total_frames), replace=False))
+     frames = []
+     for idx in indices:
+         cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+         ret, frame = cap.read()
+         if ret:
+             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frames.append(Image.fromarray(frame_rgb))
+     cap.release()
+     return frames
+
+
+ def get_tokenizer(model_name='roberta-base', max_length=512):
+     from transformers import AutoTokenizer
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     return tokenizer, max_length
+
+
+ def preprocess_text(text, tokenizer, max_length=512):
+     encoding = tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
+     return {'input_ids': encoding['input_ids'].squeeze(0), 'attention_mask': encoding['attention_mask'].squeeze(0)}
+
+
+ class ImageDeepfakeDataset(Dataset):
+     def __init__(self, hf_dataset, transform=None, label_column='label', image_column='image', flip_labels=True):
+         self.dataset = hf_dataset
+         self.transform = transform or get_image_transforms('train')
+         self.label_column = label_column
+         self.image_column = image_column
+         self.flip_labels = flip_labels
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         image = item[self.image_column]
+         if isinstance(image, dict) and 'bytes' in image:
+             image = Image.open(io.BytesIO(image['bytes']))
+         elif isinstance(image, bytes):
+             image = Image.open(io.BytesIO(image))
+         if isinstance(image, Image.Image):
+             image = image.convert('RGB')
+         else:
+             raise ValueError(f"Unexpected image type: {type(image)}")
+         image_tensor = self.transform(image)
+         label = item[self.label_column]
+         if self.flip_labels:
+             label = 1 - label
+         return {'pixel_values': image_tensor, 'labels': torch.tensor(label, dtype=torch.long)}
+
+
+ class TextDeepfakeDataset(Dataset):
+     def __init__(self, hf_dataset, tokenizer, max_length=512, text_column='text', label_column='source'):
+         self.dataset = hf_dataset
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.text_column = text_column
+         self.label_column = label_column
+         self.label_map = {'human': 0, 'ai': 1}
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         text = item[self.text_column]
+         if len(text) > 5000:
+             text = text[:5000]
+         encoding = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
+         label_str = item[self.label_column]
+         label = self.label_map.get(label_str, 0)
+         return {
+             'input_ids': encoding['input_ids'].squeeze(0),
+             'attention_mask': encoding['attention_mask'].squeeze(0),
+             'labels': torch.tensor(label, dtype=torch.long)
+         }
+
+
+ class MultimodalDataset(Dataset):
+     def __init__(self, image_dataset=None, text_dataset=None):
+         self.image_dataset = image_dataset
+         self.text_dataset = text_dataset
+         self.image_len = len(image_dataset) if image_dataset else 0
+         self.text_len = len(text_dataset) if text_dataset else 0
+         self.total_len = self.image_len + self.text_len
+
+     def __len__(self):
+         return self.total_len
+
+     def __getitem__(self, idx):
+         if idx < self.image_len:
+             item = self.image_dataset[idx]
+             item['modality'] = 'visual'
+             return item
+         else:
+             item = self.text_dataset[idx - self.image_len]
+             item['modality'] = 'text'
+             return item
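
A minimal usage sketch for the datasets defined above, assuming Hugging Face datasets that expose image/label and text/source columns; the dataset names, split, and batch sizes are placeholders, not values taken from this commit.

from datasets import load_dataset
from torch.utils.data import DataLoader

from preprocessing import (
    get_image_transforms,
    get_tokenizer,
    ImageDeepfakeDataset,
    TextDeepfakeDataset,
)

# Placeholder dataset names; substitute real Hub datasets with matching columns.
hf_images = load_dataset("user/image-deepfake-data", split="train")
hf_texts = load_dataset("user/ai-vs-human-text", split="train")

# flip_labels inverts binary labels (label -> 1 - label); set it to match the
# label convention of the chosen dataset.
image_ds = ImageDeepfakeDataset(hf_images, transform=get_image_transforms('train'), flip_labels=False)

tokenizer, max_length = get_tokenizer('roberta-base')
text_ds = TextDeepfakeDataset(hf_texts, tokenizer, max_length=max_length)

# One loader per modality; batching MultimodalDataset directly would need a
# custom collate_fn because the image and text samples have different keys.
image_loader = DataLoader(image_ds, batch_size=32, shuffle=True)
text_loader = DataLoader(text_ds, batch_size=16, shuffle=True)

batch = next(iter(image_loader))
print(batch['pixel_values'].shape)  # torch.Size([32, 3, 224, 224])
print(batch['labels'].shape)        # torch.Size([32])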
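
For video, one possible flow under the same assumptions is to sample frames with extract_video_frames, push each frame through the eval-mode transform via preprocess_image, and stack the results into a clip tensor; the video path is a placeholder.

import torch

from preprocessing import extract_video_frames, get_image_transforms, preprocess_image

eval_transform = get_image_transforms(mode='eval')                 # resize + normalize only
frames = extract_video_frames("example_clip.mp4", num_frames=32)   # list of PIL images
clip = torch.stack([preprocess_image(f, eval_transform) for f in frames])
print(clip.shape)  # torch.Size([frames_read, 3, 224, 224])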