# Author: alianassmaaa
# Commit: "Add Gradio app for multimodal deepfake detection"
# Revision: 538f73c (verified)
"""
Gradio Space pour Multimodal Deepfake Detection
===============================================
Interface web interactive pour:
- Classification d'images (avec GradCAM explicabilité)
- Classification de texte (human vs AI-generated)
- Classification multimodale (image + text)
"""
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from model import MultimodalDeepfakeDetector, GradCAM
from preprocessing import get_image_transforms, get_tokenizer, extract_video_frames
import torch.nn.functional as F
# Process-wide singletons shared by every Gradio callback. They are populated
# lazily by load_model_once() on the first request, so importing this module
# does not download or load any weights.
MODEL = None
TOKENIZER = None


def load_model_once():
    """Load the multimodal detector and its tokenizer once, then reuse them.

    The checkpoint ``multimodal_ensemble.pt`` is fetched from the Hugging Face
    Hub repo ``alianassmaaa/multimodal-deepfake-detector``. If that fails for
    any reason (``huggingface_hub`` missing, no network, missing file, ...),
    the local copy at ``/app/output/multimodal_ensemble.pt`` is used instead.

    The globals are only published after the model weights AND the tokenizer
    have loaded successfully. A failure therefore leaves ``MODEL`` as
    ``None``, and the next call retries instead of returning a half-built
    model.

    Returns:
        tuple: ``(model, tokenizer)``. ``model`` is in eval mode; ``tokenizer``
        is the first element returned by ``get_tokenizer``.
    """
    global MODEL, TOKENIZER
    if MODEL is None:
        try:
            from huggingface_hub import hf_hub_download
            ckpt_path = hf_hub_download(
                repo_id="alianassmaaa/multimodal-deepfake-detector",
                filename="multimodal_ensemble.pt",
            )
        except Exception as err:  # ImportError, network/auth/not-found errors
            # Deliberate best-effort fallback, but make it visible in the logs
            # instead of silently swallowing the failure (the original bare
            # `except:` also caught KeyboardInterrupt/SystemExit).
            import logging
            logging.getLogger(__name__).warning(
                "Hub checkpoint download failed (%r); using local checkpoint", err
            )
            ckpt_path = "/app/output/multimodal_ensemble.pt"

        # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
        # objects from the checkpoint. This is only acceptable because the file
        # comes from a repo/path we control; consider weights_only=True if the
        # checkpoint contents allow it.
        checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        # Tolerate a missing *or* explicitly-None 'config' entry.
        config = checkpoint.get('config') or {}
        text_model_name = config.get('text_model_name', 'roberta-base')

        model = MultimodalDeepfakeDetector(
            visual_pretrained=False,
            text_model_name=text_model_name,
            dropout=0.0,
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        # The tokenizer must match the text backbone the checkpoint was built with.
        tokenizer, _ = get_tokenizer(text_model_name, 512)

        # Publish only after everything above succeeded.
        MODEL, TOKENIZER = model, tokenizer
    return MODEL, TOKENIZER
def _compute_gradcam(model, img_tensor):
    """Return the GradCAM map of ``model.visual_branch`` for ``class_idx=1`` as a NumPy array.

    NOTE(review): class index 1 is presumably the "fake" class; confirm against
    the GradCAM / visual-branch implementation.
    """
    gradcam = GradCAM(model.visual_branch, model.visual_branch.get_gradcam_target_layer())
    try:
        cam = gradcam.generate(img_tensor.clone().requires_grad_(True), class_idx=1)
    finally:
        # Always remove the hooks: the model is a process-wide singleton, so
        # hooks leaked by a failing generate() would pile up across requests.
        gradcam.remove_hooks()
    # detach() makes .numpy() safe even if the returned map still tracks gradients.
    return cam.detach().cpu().squeeze().numpy()


def _render_gradcam_figure(img_np, cam_np, title):
    """Build a 1x3 figure: original image, GradCAM heatmap, and their overlay."""
    from matplotlib.figure import Figure

    # A bare Figure is not registered with pyplot's global figure manager.
    # It is therefore garbage-collected once Gradio has rendered it (no
    # per-request leak), and concurrent requests never share pyplot's
    # "current figure" state.
    fig = Figure(figsize=(15, 5), tight_layout=True)
    axes = fig.subplots(1, 3)

    height, width = img_np.shape[:2]
    # Stretch the heatmap over the image's pixel grid so the overlay lines up
    # even if the CAM resolution differs from the displayed image's.
    overlay_extent = (-0.5, width - 0.5, height - 0.5, -0.5)

    axes[0].imshow(img_np)
    axes[0].set_title('Image Originale')
    axes[1].imshow(cam_np, cmap='jet')
    axes[1].set_title('GradCAM Heatmap')
    axes[2].imshow(img_np)
    axes[2].imshow(cam_np, cmap='jet', alpha=0.4, extent=overlay_extent)
    axes[2].set_title('Overlay Explicabilité')
    for ax in axes:
        ax.axis('off')

    fig.suptitle(title, fontsize=14, fontweight='bold')
    return fig


def classify_image_gradio(image):
    """Classify an uploaded image as real or AI-generated and explain it with GradCAM.

    Args:
        image: Value of the ``gr.Image(type="numpy")`` input — a NumPy array
            or a ``PIL.Image.Image`` — or ``None`` when nothing was uploaded.

    Returns:
        tuple: ``(figure, text)``. ``figure`` is a matplotlib Figure showing
        the image, the GradCAM heatmap and their overlay (``None`` when no
        image was given). ``text`` is the human-readable verdict.
    """
    if image is None:
        return None, "Veuillez uploader une image."
    model, _ = load_model_once()

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Convert once to RGB so the displayed image matches what the model sees
    # (grayscale or RGBA uploads were previously plotted unconverted).
    rgb_image = image.convert('RGB')

    transform = get_image_transforms('eval', 224)
    img_tensor = transform(rgb_image).unsqueeze(0)

    with torch.no_grad():
        result = model(images=img_tensor, modality='visual')
    confidence = result['confidence'].item()
    prediction = "🟢 RÉEL (Real)" if confidence < 0.5 else "🔴 FAKE (AI-Generated)"
    # Report the probability of the predicted class rather than always P(fake).
    score = 1 - confidence if confidence < 0.5 else confidence
    text_result = f"{prediction}\nConfidence: {score:.2%}\n\nScore brute (P(fake)): {confidence:.4f}"

    # GradCAM needs gradients, so it runs outside the no_grad() block above.
    cam_np = _compute_gradcam(model, img_tensor)
    img_np = np.asarray(rgb_image.resize((224, 224))) / 255.0
    fig = _render_gradcam_figure(
        img_np, cam_np, f'Résultat: {prediction} (confidence: {score:.2%})'
    )
    return fig, text_result
def classify_text_gradio(text):
    """Classify a passage of text as human-written or AI-generated.

    Args:
        text: Raw content of the Gradio textbox; may be ``None`` or empty.

    Returns:
        str: The verdict, the confidence of the predicted class and the raw
        model score P(AI). If the input has fewer than 5 characters once
        surrounding whitespace is stripped, a validation message is returned
        instead.
    """
    too_short = not text or len(text.strip()) < 5
    if too_short:
        return "Veuillez entrer du texte (minimum 5 caractères)."

    model, tokenizer = load_model_once()
    batch = tokenizer(
        text,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            modality='text',
        )
    p_ai = outputs['confidence'].item()

    # Below 0.5 the text is judged human-written; show that class's probability.
    if p_ai < 0.5:
        verdict, certainty = "🟢 HUMAIN", 1 - p_ai
    else:
        verdict, certainty = "🔴 IA-GÉNÉRÉ", p_ai
    return f"{verdict}\nConfidence: {certainty:.2%}\n\nScore brute (P(AI)): {p_ai:.4f}"
def classify_multimodal_gradio(image, text):
    """Fuse image and/or text evidence into a single authentic-vs-fake verdict.

    Text is only used when it has at least 5 characters after stripping
    surrounding whitespace. At least one usable modality is required.

    Args:
        image: NumPy array or ``PIL.Image.Image`` from the Gradio image input,
            or ``None``.
        text: Caption/description from the Gradio textbox, or ``None``.

    Returns:
        str: The fused verdict and its confidence, the per-modality scores
        reported by the model, the softmax-normalised fusion weights and the
        raw fused score. A validation message is returned when neither
        modality is usable.
    """
    def has_usable_text():
        return bool(text) and len(text.strip()) >= 5

    if image is None and not has_usable_text():
        return "Veuillez fournir au moins une image ou du texte."

    model, tokenizer = load_model_once()

    images = None
    if image is not None:
        preprocess = get_image_transforms('eval', 224)
        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        images = preprocess(pil_image.convert('RGB')).unsqueeze(0)

    input_ids, attention_mask = None, None
    if has_usable_text():
        encoded = tokenizer(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids, attention_mask = encoded['input_ids'], encoded['attention_mask']

    with torch.no_grad():
        result = model(
            images=images,
            input_ids=input_ids,
            attention_mask=attention_mask,
            modality='auto',
        )
    confidence = result['confidence'].item()

    if confidence < 0.5:
        prediction, score = "🟢 AUTHENTIQUE", 1 - confidence
    else:
        prediction, score = "🔴 FAKE / IA", confidence

    # One line per modality the model actually reported a score for.
    per_modality = result['modality_scores']
    score_lines = []
    if 'visual' in per_modality:
        score_lines.append(f"Score Visuel (P(fake)): {per_modality['visual'].item():.4f}\n")
    if 'text' in per_modality:
        score_lines.append(f"Score Texte (P(AI)): {per_modality['text'].item():.4f}\n")
    modality_text = "".join(score_lines)

    # Learned fusion logits, normalised to weights for display.
    fusion = F.softmax(model.fusion_weights, dim=0)
    fusion_info = f"Poids fusion: Visuel={fusion[0].item():.3f}, Texte={fusion[1].item():.3f}"
    return f"{prediction}\nConfidence globale: {score:.2%}\n\n{modality_text}\n{fusion_info}\n\nScore brute: {confidence:.4f}"
# Build the Gradio interface: a header, three tabs (image + GradCAM, text,
# multimodal), and a footer. Each tab wires its "Analyser" button to the
# matching classify_*_gradio callback. Component creation order inside the
# `with` contexts determines the on-page layout.
with gr.Blocks(title="Détecteur Multimodal de Deepfakes", theme=gr.themes.Soft()) as demo:
    # NOTE(review): this header advertises video (frames) support, but no tab
    # below accepts video input, and `extract_video_frames` is imported without
    # being used in the visible code. Either add a video tab or reword this.
    gr.Markdown("""
# 🔍 Détecteur Multimodal de Deepfakes
### Classification de contenu: Images, Vidéos (frames), et Texte
Ce modèle combine **EfficientNet-B0** (vision) et **RoBERTa-base** (texte)
pour détecter les contenus générés par IA.
""")
    # Tab 1: single-image classification plus a GradCAM explanation plot.
    with gr.Tab("📷 Image + GradCAM"):
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="numpy", label="Uploader une image")
                btn_img = gr.Button("Analyser", variant="primary")
            with gr.Column():
                img_output_plot = gr.Plot(label="Explicabilité GradCAM")
                img_output_text = gr.Textbox(label="Résultat", lines=4)
        # classify_image_gradio returns (figure, text) -> (plot, textbox).
        btn_img.click(classify_image_gradio, inputs=[img_input], outputs=[img_output_plot, img_output_text])
        # NOTE(review): confirm this example URL resolves; it points at a
        # single file inside the dataset repo.
        gr.Examples(
            examples=[["https://huggingface.co/datasets/Hemg/deepfake-and-real-images/resolve/main/train/0/image.jpg"]],
            inputs=[img_input],
            label="Exemples (deepfake dataset)"
        )
    # Tab 2: text-only classification (human vs AI-generated).
    with gr.Tab("📝 Texte"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Entrer du texte", placeholder="Collez un article, essai, ou paragraphe...", lines=6)
                btn_text = gr.Button("Analyser", variant="primary")
            with gr.Column():
                text_output = gr.Textbox(label="Résultat", lines=6)
        btn_text.click(classify_text_gradio, inputs=[text_input], outputs=[text_output])
        # Two sample inputs that pre-fill the textbox when clicked.
        gr.Examples(
            examples=[
                ["The 2013 film 12 Years a Slave proved that slavery is a worldwide issue. The film made $150 million..."],
                ["In conclusion, the utilization of advanced machine learning algorithms enables unprecedented optimization of computational workflows."]
            ],
            inputs=[text_input]
        )
    # Tab 3: fused prediction; either input may be left empty.
    with gr.Tab("🔄 Multimodal (Image + Texte)"):
        with gr.Row():
            with gr.Column():
                mm_image = gr.Image(type="numpy", label="Image (optionnel)")
                mm_text = gr.Textbox(label="Texte (optionnel)", placeholder="Caption ou description...", lines=4)
                btn_mm = gr.Button("Analyser", variant="primary")
            with gr.Column():
                mm_output = gr.Textbox(label="Résultat Fusionné", lines=10)
        btn_mm.click(classify_multimodal_gradio, inputs=[mm_image, mm_text], outputs=[mm_output])
    # Footer: architecture summary, training datasets and a link to the model repo.
    gr.Markdown("""
---
**Architecture**: EfficientNet-B0 (vision) + RoBERTa-base (text) + Fusion pondérée + GradCAM explicabilité
**Datasets**: Hemg/deepfake-and-real-images (visuel) | artem9k/ai-text-detection-pile (texte)
[🔗 Voir sur HuggingFace Hub](https://huggingface.co/alianassmaaa/multimodal-deepfake-detector)
""")
# Start the Gradio server only when this module is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()