| """ |
| Prediction script combining DINOv2 classifier and Qwen2-VL reasoner |
| Outputs predictions.json in required format |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from transformers import ( |
| AutoImageProcessor, |
| Dinov2Model, |
| Qwen3VLForConditionalGeneration, |
| AutoProcessor |
| ) |
| from peft import PeftModel |
| from PIL import Image |
| import json |
| import os |
| from pathlib import Path |
| from tqdm import tqdm |
| from qwen_vl_utils import process_vision_info |
|
|
class DINOv2Classifier(nn.Module):
    """DINOv2 backbone with an MLP classification head.

    Uses the [CLS] token of the backbone's final hidden state as the image
    representation and maps it to ``num_classes`` logits.
    """

    def __init__(self, num_classes=3, hidden_dims=(512, 256), dropout=0.3,
                 backbone="facebook/dinov2-base"):
        """
        Args:
            num_classes: Number of output classes.
            hidden_dims: Widths of the hidden MLP layers. The default
                ``(512, 256)`` reproduces the original head layout, so
                checkpoints saved with the old hard-coded head still load.
            dropout: Dropout probability after each hidden layer.
            backbone: HuggingFace model id for the DINOv2 backbone.
        """
        super().__init__()
        self.dinov2 = Dinov2Model.from_pretrained(backbone)

        # Build the head dynamically; the input width is read from the
        # backbone config (768 for dinov2-base) instead of being hard-coded.
        layers = []
        in_dim = self.dinov2.config.hidden_size
        for width in hidden_dims:
            layers.extend([nn.Linear(in_dim, width), nn.ReLU(), nn.Dropout(dropout)])
            in_dim = width
        layers.append(nn.Linear(in_dim, num_classes))
        # With defaults this yields the same Sequential indices (0..6) as the
        # original, keeping state_dict keys compatible.
        self.classifier = nn.Sequential(*layers)

    def forward(self, pixel_values):
        """Return raw class logits of shape (batch, num_classes)."""
        outputs = self.dinov2(pixel_values)
        cls_token = outputs.last_hidden_state[:, 0]  # [CLS] token embedding
        return self.classifier(cls_token)
|
|
class GenAIDetector:
    """Two-stage detector: a DINOv2 classifier predicts the label and a
    Qwen3-VL model generates a natural-language justification."""

    def __init__(self, classifier_path):
        """Load the trained classifier checkpoint and the VLM reasoner.

        Args:
            classifier_path: Path to a ``.pth`` checkpoint containing a
                ``'model_state_dict'`` entry for ``DINOv2Classifier``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        print("Loading classifier...")
        self.classifier = DINOv2Classifier(num_classes=3).to(self.device)
        checkpoint = torch.load(classifier_path, map_location=self.device)
        self.classifier.load_state_dict(checkpoint['model_state_dict'])
        self.classifier.eval()

        self.image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

        print("Loading VLM reasoner...")
        # device_map="auto" may shard the model; generation inputs are still
        # moved to self.device below, which matches the first shard on a
        # single-GPU machine.
        base_model = Qwen3VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen3-VL-8B-Instruct",
            torch_dtype="auto",
            device_map="auto"
        )
        self.vlm = base_model
        self.vlm_processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        self.vlm.eval()

        # Index order must match the label encoding used during training.
        self.class_names = ['real', 'manipulated', 'fake']
        self.manipulation_types = {
            'real': 'none',
            'manipulated': 'inpainting',
            'fake': 'full_synthesis'
        }

    def classify_image(self, image_path):
        """Classify one image with the DINOv2 head.

        Args:
            image_path: Path to the image file.

        Returns:
            Tuple ``(pred_class, confidence)``: predicted class index and the
            full softmax probability vector as a numpy array.
        """
        # Context manager releases the file handle promptly; convert() returns
        # a detached copy so the data survives the close.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        inputs = self.image_processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].to(self.device)

        with torch.no_grad():
            logits = self.classifier(pixel_values)
            probs = torch.softmax(logits, dim=1)
            pred_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0].cpu().numpy()

        return pred_class, confidence

    def generate_reasoning(self, image_path, predicted_class):
        """Ask the VLM to explain the classifier's verdict for this image.

        Args:
            image_path: Path to the image file.
            predicted_class: Class index produced by ``classify_image``.

        Returns:
            The model's explanation as a stripped string.
        """
        class_name = self.class_names[predicted_class]

        prompt = f"The given image has been flagged as {class_name}. Explain in 2-3 sentences why that might be. Focus on specific features which indicated this."

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        text = self.vlm_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.vlm_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        inputs = inputs.to(self.device)

        with torch.no_grad():
            output_ids = self.vlm.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                do_sample=True
            )

        # Decode only the newly generated tokens: generate() echoes the prompt
        # tokens at the front of each sequence, so slice them off instead of
        # decoding everything and splitting on the word "assistant" (which
        # breaks whenever the response itself contains that word).
        generated_ids = [
            out_seq[len(in_seq):]
            for in_seq, out_seq in zip(inputs.input_ids, output_ids)
        ]
        reasoning = self.vlm_processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0].strip()

        return reasoning

    def predict(self, image_path):
        """Full prediction pipeline for a single image.

        Returns:
            Dict with ``authenticity_score`` (float, 2 decimals),
            ``manipulation_type`` and ``vlm_reasoning``.
        """
        pred_class, confidence = self.classify_image(image_path)

        # NOTE(review): score is 1 - P(real), i.e. HIGH means likely
        # fake/manipulated — confirm this matches the required output spec.
        authenticity_score = float(1.0 - confidence[0])

        class_name = self.class_names[pred_class]
        manipulation_type = self.manipulation_types[class_name]

        reasoning = self.generate_reasoning(image_path, pred_class)

        return {
            'authenticity_score': round(authenticity_score, 2),
            'manipulation_type': manipulation_type,
            'vlm_reasoning': reasoning
        }
|
|
def main(image_dir, classifier_path, output_file):
    """Run the detector over every image in ``image_dir`` and write JSON.

    Args:
        image_dir: Directory containing the images to score.
        classifier_path: Path to the trained DINOv2 checkpoint (.pth).
        output_file: Destination path for the predictions JSON array.
    """
    detector = GenAIDetector(classifier_path)

    # Collect candidate files. Dedupe via a set: on case-insensitive
    # filesystems (macOS/Windows) the '*.jpg' and '*.JPG' globs each match
    # the SAME files, which previously produced duplicate predictions.
    # Sort afterwards so output order is deterministic.
    image_extensions = ['.jpg', '.jpeg', '.png']
    found = set()
    for ext in image_extensions:
        found.update(Path(image_dir).glob(f'*{ext}'))
        found.update(Path(image_dir).glob(f'*{ext.upper()}'))
    image_files = sorted(found)

    print(f"Found {len(image_files)} images")

    predictions = []
    for image_path in tqdm(image_files, desc="Processing images"):
        try:
            result = detector.predict(str(image_path))
            result['image_name'] = image_path.name
            predictions.append(result)
        except Exception as e:
            # Best-effort batch run: report and skip a failing image rather
            # than aborting the whole directory.
            print(f"Error processing {image_path.name}: {str(e)}")
            continue

    with open(output_file, 'w') as f:
        json.dump(predictions, f, indent=2)

    print(f"\n✓ Processed {len(predictions)} images")
    print(f"✓ Saved predictions to {output_file}")
|
|
| if __name__ == "__main__": |
| import argparse |
| |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--image_dir', type=str, default='./test_images', |
| help='Directory containing images to predict') |
| parser.add_argument('--classifier_path', type=str, default='best_model.pth', |
| help='Path to trained DINOv2 checkpoint (.pth file)') |
| parser.add_argument('--output_file', type=str, default='predictions.json', |
| help='Output JSON file') |
| |
| args = parser.parse_args() |
| |
| main(args.image_dir, args.classifier_path, args.output_file) |